round: reduce the decimal scale of ROUND's result to `decimal_places`

When `decimal_places` is a constant, ROUND over a decimal now returns
Decimal(p, min(s, max(decimal_places, 0))) so the digits no longer needed for
the fraction are reclaimed for the integer part (e.g. round(999.9::decimal(4,1))
yields 1000 instead of overflowing). When `decimal_places` is a column the
original scale is kept, because no single output scale fits every row.

Review changes folded into this revision of the patch:
  * The output scale is computed in i32 and narrowed to i8 afterwards. The
    previous `decimal_places.max(0) as i8` wrapped for values above i8::MAX
    (200 became -56) and produced a negative output scale.
  * An Int64 `decimal_places` literal is saturated to the i32 range instead of
    being wrapped by `as i32`.
  * The output field is nullable when any argument is nullable, not only the
    first one.
  * `round_decimal` uses `i32::unsigned_abs` instead of negating and converting,
    which removes an overflow on i32::MIN and a dead error branch.
  * Doc comments were added to `return_field_from_args` and `round_decimal`,
    and a sqllogictest regression case covers `decimal_places > i8::MAX`.

NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. Blank
context lines, the indentation of context lines and generic parameters
(`Result<...>`, `Option<ArrayRef>`, `calculate_binary_math::<...>`,
`round_decimal<V: ArrowNativeTypeOp>`) were restored by inference, and hunk
offsets were recomputed for the added lines. The post-image hashes on the
`index` lines are stale. Verify the context against the base revision, or
apply with `git apply --ignore-whitespace --recount`.

diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs
index 8c25c57740d5..3ca23cfd9a58 100644
--- a/datafusion/functions/src/math/round.rs
+++ b/datafusion/functions/src/math/round.rs
@@ -27,6 +27,7 @@ use arrow::datatypes::{
     ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
     Decimal256Type, Float32Type, Float64Type, Int32Type,
 };
+use arrow::datatypes::{Field, FieldRef};
 use arrow::error::ArrowError;
 use datafusion_common::types::{
     NativeType, logical_float32, logical_float64, logical_int32,
@@ -34,10 +35,11 @@ use datafusion_common::types::{
 use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
 use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
 use datafusion_expr::{
-    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
-    TypeSignature, TypeSignatureClass, Volatility,
+    Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
+    ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
 };
 use datafusion_macros::user_doc;
+use std::sync::Arc;
 
 #[user_doc(
     doc_section(label = "Math Functions"),
@@ -117,15 +119,85 @@ impl ScalarUDFImpl for RoundFunc {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        Ok(match arg_types[0].clone() {
+    /// Derives the output field from the value field and, when it is a constant,
+    /// the `decimal_places` argument.
+    ///
+    /// Decimal inputs keep their precision. With a constant `decimal_places` the
+    /// scale becomes `min(scale, max(decimal_places, 0))`; otherwise it is kept.
+    /// `Float32` stays `Float32` and every other input type maps to `Float64`.
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let input_field = &args.arg_fields[0];
+        let input_type = input_field.data_type();
+
+        // Get decimal_places from scalar_arguments
+        // If dp is not a constant scalar, we must keep the original scale because
+        // we can't determine a single output scale for varying per-row dp values.
+        let (decimal_places, dp_is_scalar): (i32, bool) =
+            if args.scalar_arguments.len() > 1 {
+                match args.scalar_arguments[1] {
+                    Some(ScalarValue::Int32(Some(v))) => (*v, true),
+                    // Saturate instead of wrapping so an out-of-range Int64 can
+                    // never flip sign and produce a bogus output scale.
+                    Some(ScalarValue::Int64(Some(v))) => {
+                        let dp = (*v).clamp(i64::from(i32::MIN), i64::from(i32::MAX));
+                        (dp as i32, true)
+                    }
+                    _ => (0, false), // dp is a column or null - can't determine scale
+                }
+            } else {
+                (0, true) // No dp argument means default to 0
+            };
+
+        // Calculate return type based on input type
+        // For decimals: reduce scale to decimal_places (reclaims precision for integer part)
+        // This matches Spark/DuckDB behavior where ROUND adjusts the scale
+        // BUT only if dp is a constant - otherwise keep original scale
+        // The min is computed in i32 and only then narrowed: the result never exceeds
+        // the input scale, so the final cast to i8 is lossless.
+        let return_type = match input_type {
             Float32 => Float32,
-            dt @ Decimal128(_, _)
-            | dt @ Decimal256(_, _)
-            | dt @ Decimal32(_, _)
-            | dt @ Decimal64(_, _) => dt,
+            Decimal32(precision, scale) => {
+                if dp_is_scalar {
+                    let new_scale = i32::from(*scale).min(decimal_places.max(0)) as i8;
+                    Decimal32(*precision, new_scale)
+                } else {
+                    Decimal32(*precision, *scale)
+                }
+            }
+            Decimal64(precision, scale) => {
+                if dp_is_scalar {
+                    let new_scale = i32::from(*scale).min(decimal_places.max(0)) as i8;
+                    Decimal64(*precision, new_scale)
+                } else {
+                    Decimal64(*precision, *scale)
+                }
+            }
+            Decimal128(precision, scale) => {
+                if dp_is_scalar {
+                    let new_scale = i32::from(*scale).min(decimal_places.max(0)) as i8;
+                    Decimal128(*precision, new_scale)
+                } else {
+                    Decimal128(*precision, *scale)
+                }
+            }
+            Decimal256(precision, scale) => {
+                if dp_is_scalar {
+                    let new_scale = i32::from(*scale).min(decimal_places.max(0)) as i8;
+                    Decimal256(*precision, new_scale)
+                } else {
+                    Decimal256(*precision, *scale)
+                }
+            }
             _ => Float64,
-        })
+        };
+
+        // A NULL in either argument can produce a NULL result, so consider every field.
+        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+        Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!("use return_field_from_args instead")
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -141,7 +213,6 @@ impl ScalarUDFImpl for RoundFunc {
             &default_decimal_places
         };
 
-        // Scalar fast path for float and decimal types - avoid array conversion overhead
         if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
             (&args.args[0], decimal_places)
         {
@@ -169,27 +240,32 @@ impl ScalarUDFImpl for RoundFunc {
                     Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
                 }
                 ScalarValue::Decimal128(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                    // Reduce scale to reclaim integer precision
+                    let new_scale = i32::from(*scale).min(dp.max(0)) as i8;
+                    let rounded = round_decimal(*v, *scale, new_scale, dp)?;
                     let scalar =
-                        ScalarValue::Decimal128(Some(rounded), *precision, *scale);
+                        ScalarValue::Decimal128(Some(rounded), *precision, new_scale);
                     Ok(ColumnarValue::Scalar(scalar))
                 }
                 ScalarValue::Decimal256(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                    let new_scale = i32::from(*scale).min(dp.max(0)) as i8;
+                    let rounded = round_decimal(*v, *scale, new_scale, dp)?;
                     let scalar =
-                        ScalarValue::Decimal256(Some(rounded), *precision, *scale);
+                        ScalarValue::Decimal256(Some(rounded), *precision, new_scale);
                     Ok(ColumnarValue::Scalar(scalar))
                 }
                 ScalarValue::Decimal64(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                    let new_scale = i32::from(*scale).min(dp.max(0)) as i8;
+                    let rounded = round_decimal(*v, *scale, new_scale, dp)?;
                     let scalar =
-                        ScalarValue::Decimal64(Some(rounded), *precision, *scale);
+                        ScalarValue::Decimal64(Some(rounded), *precision, new_scale);
                     Ok(ColumnarValue::Scalar(scalar))
                 }
                 ScalarValue::Decimal32(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                    let new_scale = i32::from(*scale).min(dp.max(0)) as i8;
+                    let rounded = round_decimal(*v, *scale, new_scale, dp)?;
                     let scalar =
-                        ScalarValue::Decimal32(Some(rounded), *precision, *scale);
+                        ScalarValue::Decimal32(Some(rounded), *precision, new_scale);
                     Ok(ColumnarValue::Scalar(scalar))
                 }
                 _ => {
@@ -200,7 +276,12 @@ impl ScalarUDFImpl for RoundFunc {
                     }
                 }
             }
         } else {
-            round_columnar(&args.args[0], decimal_places, args.number_rows)
+            round_columnar(
+                &args.args[0],
+                decimal_places,
+                args.number_rows,
+                args.return_type(),
+            )
         }
     }
@@ -228,13 +309,14 @@ fn round_columnar(
     value: &ColumnarValue,
     decimal_places: &ColumnarValue,
     number_rows: usize,
+    return_type: &DataType,
 ) -> Result<ColumnarValue> {
     let value_array = value.to_array(number_rows)?;
     let both_scalars = matches!(value, ColumnarValue::Scalar(_))
         && matches!(decimal_places, ColumnarValue::Scalar(_));
 
-    let arr: ArrayRef = match value_array.data_type() {
-        Float64 => {
+    let arr: ArrayRef = match (value_array.data_type(), return_type) {
+        (Float64, _) => {
             let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
                 value_array.as_ref(),
                 decimal_places,
@@ -242,7 +324,7 @@ fn round_columnar(
             )?;
             result as _
         }
-        Float32 => {
+        (Float32, _) => {
             let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
                 value_array.as_ref(),
                 decimal_places,
@@ -250,7 +332,8 @@ fn round_columnar(
             )?;
             result as _
         }
-        Decimal32(precision, scale) => {
+        (Decimal32(_, scale), Decimal32(precision, new_scale)) => {
+            // reduce scale to reclaim integer precision
             let result = calculate_binary_decimal_math::<
                 Decimal32Type,
                 Int32Type,
@@ -259,13 +342,13 @@ fn round_columnar(
             >(
                 value_array.as_ref(),
                 decimal_places,
-                |v, dp| round_decimal(v, *scale, dp),
+                |v, dp| round_decimal(v, *scale, *new_scale, dp),
                 *precision,
-                *scale,
+                *new_scale,
             )?;
             result as _
         }
-        Decimal64(precision, scale) => {
+        (Decimal64(_, scale), Decimal64(precision, new_scale)) => {
             let result = calculate_binary_decimal_math::<
                 Decimal64Type,
                 Int32Type,
@@ -274,13 +357,13 @@ fn round_columnar(
             >(
                 value_array.as_ref(),
                 decimal_places,
-                |v, dp| round_decimal(v, *scale, dp),
+                |v, dp| round_decimal(v, *scale, *new_scale, dp),
                 *precision,
-                *scale,
+                *new_scale,
             )?;
             result as _
         }
-        Decimal128(precision, scale) => {
+        (Decimal128(_, scale), Decimal128(precision, new_scale)) => {
             let result = calculate_binary_decimal_math::<
                 Decimal128Type,
                 Int32Type,
@@ -289,13 +372,13 @@ fn round_columnar(
             >(
                 value_array.as_ref(),
                 decimal_places,
-                |v, dp| round_decimal(v, *scale, dp),
+                |v, dp| round_decimal(v, *scale, *new_scale, dp),
                 *precision,
-                *scale,
+                *new_scale,
             )?;
             result as _
         }
-        Decimal256(precision, scale) => {
+        (Decimal256(_, scale), Decimal256(precision, new_scale)) => {
             let result = calculate_binary_decimal_math::<
                 Decimal256Type,
                 Int32Type,
@@ -304,13 +387,13 @@ fn round_columnar(
             >(
                 value_array.as_ref(),
                 decimal_places,
-                |v, dp| round_decimal(v, *scale, dp),
+                |v, dp| round_decimal(v, *scale, *new_scale, dp),
                 *precision,
-                *scale,
+                *new_scale,
             )?;
             result as _
         }
-        other => exec_err!("Unsupported data type {other:?} for function round")?,
+        (other, _) => exec_err!("Unsupported data type {other:?} for function round")?,
     };
 
     if both_scalars {
@@ -334,10 +417,17 @@ where
 
+/// Rounds the raw decimal `value`, stored at `input_scale`, to `decimal_places`.
+///
+/// When `output_scale` equals the reduced scale
+/// `max(0, min(input_scale, decimal_places))` the result is expressed at that
+/// scale; any other `output_scale` is treated as `input_scale`. A value that
+/// already has at most `decimal_places` fractional digits is returned unchanged.
 fn round_decimal<V: ArrowNativeTypeOp>(
     value: V,
-    scale: i8,
+    input_scale: i8,
+    output_scale: i8,
     decimal_places: i32,
 ) -> Result<V, ArrowError> {
-    let diff = i64::from(scale) - i64::from(decimal_places);
+    let diff = i64::from(input_scale) - i64::from(decimal_places);
     if diff <= 0 {
         return Ok(value);
     }
@@ -358,7 +448,7 @@ fn round_decimal(
 
     let factor = ten.pow_checked(diff).map_err(|_| {
         ArrowError::ComputeError(format!(
-            "Overflow while rounding decimal with scale {scale} and decimal places {decimal_places}"
+            "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
         ))
     })?;
 
@@ -377,9 +467,37 @@ fn round_decimal(
         })?;
     }
 
-    quotient
-        .mul_checked(factor)
-        .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
+    // Determine how to scale the result based on output_scale vs computed scale
+    // computed_scale = max(0, min(input_scale, decimal_places))
+    let computed_scale = if decimal_places >= 0 {
+        (input_scale as i32).min(decimal_places).max(0) as i8
+    } else {
+        0
+    };
+
+    if output_scale == computed_scale {
+        // scale reduction, return quotient directly (or shifted for negative dp)
+        if decimal_places >= 0 {
+            Ok(quotient)
+        } else {
+            // For negative decimal_places, multiply by 10^(-decimal_places) to shift left.
+            // `unsigned_abs` cannot overflow, unlike negating `i32::MIN`.
+            let neg_dp = decimal_places.unsigned_abs();
+            let shift_factor = ten.pow_checked(neg_dp).map_err(|_| {
+                ArrowError::ComputeError(format!(
+                    "Overflow computing shift factor for decimal places {decimal_places}"
+                ))
+            })?;
+            quotient.mul_checked(shift_factor).map_err(|_| {
+                ArrowError::ComputeError("Overflow while rounding decimal".into())
+            })
+        }
+    } else {
+        // Keep original scale behavior: multiply back by factor
+        quotient.mul_checked(factor).map_err(|_| {
+            ArrowError::ComputeError("Overflow while rounding decimal".into())
+        })
+    }
 }
 
 #[cfg(test)]
@@ -397,12 +515,14 @@ mod test {
         decimal_places: Option<ArrayRef>,
     ) -> Result<ArrayRef> {
         let number_rows = value.len();
+        let return_type = value.data_type().clone();
         let value = ColumnarValue::Array(value);
         let decimal_places = decimal_places
             .map(ColumnarValue::Array)
             .unwrap_or_else(|| ColumnarValue::Scalar(ScalarValue::Int32(Some(0))));
 
-        let result = super::round_columnar(&value, &decimal_places, number_rows)?;
+        let result =
+            super::round_columnar(&value, &decimal_places, number_rows, &return_type)?;
         match result {
             ColumnarValue::Array(array) => Ok(array),
             ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt
index f53f4939299c..eca2c88bb5f8 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -782,7 +782,7 @@ query TR
 select arrow_typeof(round(173975140545.855, 2)),
        round(173975140545.855, 2);
 ----
-Decimal128(15, 3) 173975140545.86
+Decimal128(15, 2) 173975140545.86
 
 # smoke test for decimal parsing
 query RT
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index b0307c4630e2..58b610867f87 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -931,11 +931,12 @@ query error Arrow error: Cast error: Can't cast value 2147483649 to type Int32
 select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 2147483649);
 
 # round decimal should not cast to float
+# scale reduces to match decimal_places
 query TR
 select arrow_typeof(round('173975140545.855'::decimal(38,10), 2)),
        round('173975140545.855'::decimal(38,10), 2);
 ----
-Decimal128(38, 10) 173975140545.86
+Decimal128(38, 2) 173975140545.86
 
 # round decimal ties away from zero
 query RRRR
@@ -951,15 +952,38 @@ query TR
 select arrow_typeof(round('12345.55'::decimal(10,2), -1)),
        round('12345.55'::decimal(10,2), -1);
 ----
-Decimal128(10, 2) 12350
+Decimal128(10, 0) 12350
 
 # round decimal256 keeps decimals
 query TR
 select arrow_typeof(round('1234.5678'::decimal(50,4), 2)),
        round('1234.5678'::decimal(50,4), 2);
 ----
-Decimal256(50, 4) 1234.57
+Decimal256(50, 2) 1234.57
 
+# round decimal with carry-over (reduce scale)
+# Scale reduces from 1 to 0, allowing extra digit for carry-over
+query TRRR
+select arrow_typeof(round('999.9'::decimal(4,1))),
+       round('999.9'::decimal(4,1)),
+       round('-999.9'::decimal(4,1)),
+       round('99.99'::decimal(4,2));
+----
+Decimal128(4, 0) 1000 -1000 100
+
+# round decimal at max precision now works (scale reduction handles overflow)
+query TR
+select arrow_typeof(round('9999999999999999999999999999999999999.9'::decimal(38,1))),
+       round('9999999999999999999999999999999999999.9'::decimal(38,1));
+----
+Decimal128(38, 0) 10000000000000000000000000000000000000
+
+# decimal_places above i8::MAX must not wrap into a negative output scale
+query TR
+select arrow_typeof(round('1234.5678'::decimal(50,4), 200)),
+       round('1234.5678'::decimal(50,4), 200);
+----
+Decimal256(50, 4) 1234.5678
 
 ## signum
 