From 811e7cd5bfc9571bef7b635f0b30030c1dcf424f Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 21 Jan 2026 16:06:30 +0530 Subject: [PATCH 1/5] fix: increase ROUND decimal precision to prevent overflow truncation --- datafusion/functions/src/math/round.rs | 57 ++++++++++++++----- .../sqllogictest/test_files/decimal.slt | 2 +- datafusion/sqllogictest/test_files/scalar.slt | 13 ++++- 3 files changed, 54 insertions(+), 18 deletions(-) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 8c25c57740d5f..980405179ff4c 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -24,8 +24,9 @@ use arrow::datatypes::DataType::{ Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, }; use arrow::datatypes::{ - ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, - Decimal256Type, Float32Type, Float64Type, Int32Type, + ArrowNativeTypeOp, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type, + Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, }; use arrow::error::ArrowError; use datafusion_common::types::{ @@ -118,12 +119,28 @@ impl ScalarUDFImpl for RoundFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0].clone() { + Ok(match &arg_types[0] { Float32 => Float32, - dt @ Decimal128(_, _) - | dt @ Decimal256(_, _) - | dt @ Decimal32(_, _) - | dt @ Decimal64(_, _) => dt, + // For decimal types, increase precision by 1 to accommodate potential + // carry-over from rounding (e.g., 999.9 -> 1000.0 requires an extra digit). + // This matches PostgreSQL behavior where ROUND can increase the number + // of digits before the decimal point. + Decimal32(precision, scale) => { + let new_precision = (*precision + 1).min(DECIMAL32_MAX_PRECISION); + Decimal32(new_precision, *scale) + } + Decimal64(precision, scale) => { + let new_precision = (*precision + 1).min(DECIMAL64_MAX_PRECISION); + Decimal64(new_precision, *scale) + } + Decimal128(precision, scale) => { + let new_precision = (*precision + 1).min(DECIMAL128_MAX_PRECISION); + Decimal128(new_precision, *scale) + } + Decimal256(precision, scale) => { + let new_precision = (*precision + 1).min(DECIMAL256_MAX_PRECISION); + Decimal256(new_precision, *scale) + } _ => Float64, }) } @@ -170,26 +187,31 @@ impl ScalarUDFImpl for RoundFunc { } ScalarValue::Decimal128(Some(v), precision, scale) => { let rounded = round_decimal(*v, *scale, dp)?; + // Use increased precision from return_type to avoid overflow + let new_precision = (*precision + 1).min(DECIMAL128_MAX_PRECISION); let scalar = - ScalarValue::Decimal128(Some(rounded), *precision, *scale); + ScalarValue::Decimal128(Some(rounded), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal256(Some(v), precision, scale) => { let rounded = round_decimal(*v, *scale, dp)?; + let new_precision = (*precision + 1).min(DECIMAL256_MAX_PRECISION); let scalar = - ScalarValue::Decimal256(Some(rounded), *precision, *scale); + ScalarValue::Decimal256(Some(rounded), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal64(Some(v), precision, scale) => { let rounded = round_decimal(*v, *scale, dp)?; + let new_precision = (*precision + 1).min(DECIMAL64_MAX_PRECISION); let scalar = - ScalarValue::Decimal64(Some(rounded), *precision, *scale); + ScalarValue::Decimal64(Some(rounded), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal32(Some(v), precision, scale) => { let rounded = round_decimal(*v, *scale, dp)?; + let new_precision = (*precision + 1).min(DECIMAL32_MAX_PRECISION); let scalar = - ScalarValue::Decimal32(Some(rounded), *precision, *scale); + ScalarValue::Decimal32(Some(rounded), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } _ => { @@ -251,6 +273,8 @@ fn round_columnar( result as _ } Decimal32(precision, scale) => { + // Use increased precision to avoid overflow from rounding carry-over + let new_precision = (*precision + 1).min(DECIMAL32_MAX_PRECISION); let result = calculate_binary_decimal_math::< Decimal32Type, Int32Type, @@ -260,12 +284,13 @@ fn round_columnar( value_array.as_ref(), decimal_places, |v, dp| round_decimal(v, *scale, dp), - *precision, + new_precision, *scale, )?; result as _ } Decimal64(precision, scale) => { + let new_precision = (*precision + 1).min(DECIMAL64_MAX_PRECISION); let result = calculate_binary_decimal_math::< Decimal64Type, Int32Type, @@ -275,12 +300,13 @@ fn round_columnar( value_array.as_ref(), decimal_places, |v, dp| round_decimal(v, *scale, dp), - *precision, + new_precision, *scale, )?; result as _ } Decimal128(precision, scale) => { + let new_precision = (*precision + 1).min(DECIMAL128_MAX_PRECISION); let result = calculate_binary_decimal_math::< Decimal128Type, Int32Type, @@ -290,12 +316,13 @@ fn round_columnar( value_array.as_ref(), decimal_places, |v, dp| round_decimal(v, *scale, dp), - *precision, + new_precision, *scale, )?; result as _ } Decimal256(precision, scale) => { + let new_precision = (*precision + 1).min(DECIMAL256_MAX_PRECISION); let result = calculate_binary_decimal_math::< Decimal256Type, Int32Type, @@ -305,7 +332,7 @@ fn round_columnar( value_array.as_ref(), decimal_places, |v, dp| round_decimal(v, *scale, dp), - *precision, + new_precision, *scale, )?; result as _ diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index f53f4939299c5..2df76ac9add5d 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -782,7 +782,7 @@ query TR select arrow_typeof(round(173975140545.855, 2)), round(173975140545.855, 2); ---- -Decimal128(15, 3) 173975140545.86 +Decimal128(16, 3) 173975140545.86 # smoke test for decimal parsing query RT diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index b0307c4630e20..3891a2783786d 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -951,15 +951,24 @@ query TR select arrow_typeof(round('12345.55'::decimal(10,2), -1)), round('12345.55'::decimal(10,2), -1); ---- -Decimal128(10, 2) 12350 +Decimal128(11, 2) 12350 # round decimal256 keeps decimals query TR select arrow_typeof(round('1234.5678'::decimal(50,4), 2)), round('1234.5678'::decimal(50,4), 2); ---- -Decimal256(50, 4) 1234.57 +Decimal256(51, 4) 1234.57 +# round decimal with carry-over that increases precision (issue fix) +# Previously this returned 100.0 due to precision overflow, now correctly returns 1000.0 +query TRRR +select arrow_typeof(round('999.9'::decimal(4,1))), + round('999.9'::decimal(4,1)), + round('-999.9'::decimal(4,1)), + round('99.99'::decimal(4,2)); +---- +Decimal128(5, 1) 1000 -1000 100 ## signum From 91ce44caea7f875fd2d80dcc49a7954d37d1970a Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 21 Jan 2026 23:39:54 +0530 Subject: [PATCH 2/5] handle overflow condition --- datafusion/functions/src/math/round.rs | 43 +++++++++++++------ datafusion/sqllogictest/test_files/scalar.slt | 5 +++ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 980405179ff4c..455e51a9a3253 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -186,30 +186,29 @@ impl ScalarUDFImpl for RoundFunc { Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) } ScalarValue::Decimal128(Some(v), precision, scale) => { - let rounded = round_decimal(*v, *scale, dp)?; - // Use increased precision from return_type to avoid overflow let new_precision = (*precision + 1).min(DECIMAL128_MAX_PRECISION); + let rounded = round_decimal(*v, new_precision, *scale, dp)?; let scalar = ScalarValue::Decimal128(Some(rounded), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal256(Some(v), precision, scale) => { - let rounded = round_decimal(*v, *scale, dp)?; let new_precision = (*precision + 1).min(DECIMAL256_MAX_PRECISION); + let rounded = round_decimal(*v, new_precision, *scale, dp)?; let scalar = ScalarValue::Decimal256(Some(rounded), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal64(Some(v), precision, scale) => { - let rounded = round_decimal(*v, *scale, dp)?; let new_precision = (*precision + 1).min(DECIMAL64_MAX_PRECISION); + let rounded = round_decimal(*v, new_precision, *scale, dp)?; let scalar = ScalarValue::Decimal64(Some(rounded), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal32(Some(v), precision, scale) => { - let rounded = round_decimal(*v, *scale, dp)?; let new_precision = (*precision + 1).min(DECIMAL32_MAX_PRECISION); + let rounded = round_decimal(*v, new_precision, *scale, dp)?; let scalar = ScalarValue::Decimal32(Some(rounded), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) @@ -283,7 +282,7 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), + |v, dp| round_decimal(v, new_precision, *scale, dp), new_precision, *scale, )?; @@ -299,7 +298,7 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), + |v, dp| round_decimal(v, new_precision, *scale, dp), new_precision, *scale, )?; @@ -315,7 +314,7 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), + |v, dp| round_decimal(v, new_precision, *scale, dp), new_precision, *scale, )?; @@ -331,7 +330,7 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), + |v, dp| round_decimal(v, new_precision, *scale, dp), new_precision, *scale, )?; @@ -361,6 +360,7 @@ where fn round_decimal( value: V, + precision: u8, scale: i8, decimal_places: i32, ) -> Result { @@ -404,9 +404,28 @@ fn round_decimal( })?; } - quotient - .mul_checked(factor) - .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into())) + let result = quotient.mul_checked(factor).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding decimal".into()) + })?; + + // Validate the result fits within the precision + // The max value for a given precision is 10^precision - 1 + let max_value = ten.pow_checked(precision as u32).map_err(|_| { + ArrowError::ComputeError(format!( + "Cannot compute max value for precision {precision}" + )) + })?; + + // Check if absolute value exceeds max (using comparison since we can't easily get abs) + // For positive: result >= max_value means overflow + // For negative: result <= -max_value means overflow + if result >= max_value || result <= max_value.neg_wrapping() { + return Err(ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}" + ))); + } + + Ok(result) } #[cfg(test)] diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 3891a2783786d..7c32087499a89 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -970,6 +970,11 @@ select arrow_typeof(round('999.9'::decimal(4,1))), ---- Decimal128(5, 1) 1000 -1000 100 +# round decimal at max precision returns error when result would overflow +# 37 nines + .9 rounded would need 38 digits, but precision is maxed at 38 +query error Decimal overflow: rounded value exceeds precision 38 +select round('9999999999999999999999999999999999999.9'::decimal(38,1)); + ## signum # signum scalar function From 7cf0bcba5a4fc899cedd4e639358ebf2bff6bb7d Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 22 Jan 2026 15:30:23 +0530 Subject: [PATCH 3/5] using const from arrow --- datafusion/functions/src/math/round.rs | 122 ++++++++++++++++++------- 1 file changed, 87 insertions(+), 35 deletions(-) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 455e51a9a3253..0a4931c2461dd 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -27,6 +27,10 @@ use arrow::datatypes::{ ArrowNativeTypeOp, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, + MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION, + MAX_DECIMAL128_FOR_EACH_PRECISION, MAX_DECIMAL256_FOR_EACH_PRECISION, + MIN_DECIMAL32_FOR_EACH_PRECISION, MIN_DECIMAL64_FOR_EACH_PRECISION, + MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL256_FOR_EACH_PRECISION, }; use arrow::error::ArrowError; use datafusion_common::types::{ @@ -186,31 +190,38 @@ impl ScalarUDFImpl for RoundFunc { Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) } ScalarValue::Decimal128(Some(v), precision, scale) => { + // Use increased precision from return_type to avoid overflow let new_precision = (*precision + 1).min(DECIMAL128_MAX_PRECISION); - let rounded = round_decimal(*v, new_precision, *scale, dp)?; + let rounded = round_decimal(*v, *scale, dp)?; + let validated = + validate_decimal128_precision(rounded, new_precision)?; let scalar = - ScalarValue::Decimal128(Some(rounded), new_precision, *scale); + ScalarValue::Decimal128(Some(validated), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal256(Some(v), precision, scale) => { let new_precision = (*precision + 1).min(DECIMAL256_MAX_PRECISION); - let rounded = round_decimal(*v, new_precision, *scale, dp)?; + let rounded = round_decimal(*v, *scale, dp)?; + let validated = + validate_decimal256_precision(rounded, new_precision)?; let scalar = - ScalarValue::Decimal256(Some(rounded), new_precision, *scale); + ScalarValue::Decimal256(Some(validated), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal64(Some(v), precision, scale) => { let new_precision = (*precision + 1).min(DECIMAL64_MAX_PRECISION); - let rounded = round_decimal(*v, new_precision, *scale, dp)?; + let rounded = round_decimal(*v, *scale, dp)?; + let validated = validate_decimal64_precision(rounded, new_precision)?; let scalar = - ScalarValue::Decimal64(Some(rounded), new_precision, *scale); + ScalarValue::Decimal64(Some(validated), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal32(Some(v), precision, scale) => { let new_precision = (*precision + 1).min(DECIMAL32_MAX_PRECISION); - let rounded = round_decimal(*v, new_precision, *scale, dp)?; + let rounded = round_decimal(*v, *scale, dp)?; + let validated = validate_decimal32_precision(rounded, new_precision)?; let scalar = - ScalarValue::Decimal32(Some(rounded), new_precision, *scale); + ScalarValue::Decimal32(Some(validated), new_precision, *scale); Ok(ColumnarValue::Scalar(scalar)) } _ => { @@ -282,7 +293,10 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, new_precision, *scale, dp), + |v, dp| { + round_decimal(v, *scale, dp) + .and_then(|r| validate_decimal32_precision(r, new_precision)) + }, new_precision, *scale, )?; @@ -298,7 +312,10 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, new_precision, *scale, dp), + |v, dp| { + round_decimal(v, *scale, dp) + .and_then(|r| validate_decimal64_precision(r, new_precision)) + }, new_precision, *scale, )?; @@ -314,7 +331,10 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, new_precision, *scale, dp), + |v, dp| { + round_decimal(v, *scale, dp) + .and_then(|r| validate_decimal128_precision(r, new_precision)) + }, new_precision, *scale, )?; @@ -330,7 +350,10 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, new_precision, *scale, dp), + |v, dp| { + round_decimal(v, *scale, dp) + .and_then(|r| validate_decimal256_precision(r, new_precision)) + }, new_precision, *scale, )?; @@ -358,9 +381,57 @@ where Ok((value * factor).round() / factor) } +/// Validate that an i32 (Decimal32) value fits within the specified precision. +/// Uses Arrow's pre-defined MAX/MIN_DECIMAL32_FOR_EACH_PRECISION constants. +fn validate_decimal32_precision(value: i32, precision: u8) -> Result { + let max = MAX_DECIMAL32_FOR_EACH_PRECISION[precision as usize]; + let min = MIN_DECIMAL32_FOR_EACH_PRECISION[precision as usize]; + if value > max || value < min { + return Err(ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}" + ))); + } + Ok(value) +} + +fn validate_decimal64_precision(value: i64, precision: u8) -> Result { + let max = MAX_DECIMAL64_FOR_EACH_PRECISION[precision as usize]; + let min = MIN_DECIMAL64_FOR_EACH_PRECISION[precision as usize]; + if value > max || value < min { + return Err(ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}" + ))); + } + Ok(value) +} + +fn validate_decimal128_precision(value: i128, precision: u8) -> Result { + let max = MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize]; + let min = MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize]; + if value > max || value < min { + return Err(ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}" + ))); + } + Ok(value) +} + +fn validate_decimal256_precision( + value: arrow::datatypes::i256, + precision: u8, +) -> Result { + let max = MAX_DECIMAL256_FOR_EACH_PRECISION[precision as usize]; + let min = MIN_DECIMAL256_FOR_EACH_PRECISION[precision as usize]; + if value > max || value < min { + return Err(ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}" + ))); + } + Ok(value) +} + fn round_decimal( value: V, - precision: u8, scale: i8, decimal_places: i32, ) -> Result { @@ -404,28 +475,9 @@ fn round_decimal( })?; } - let result = quotient.mul_checked(factor).map_err(|_| { - ArrowError::ComputeError("Overflow while rounding decimal".into()) - })?; - - // Validate the result fits within the precision - // The max value for a given precision is 10^precision - 1 - let max_value = ten.pow_checked(precision as u32).map_err(|_| { - ArrowError::ComputeError(format!( - "Cannot compute max value for precision {precision}" - )) - })?; - - // Check if absolute value exceeds max (using comparison since we can't easily get abs) - // For positive: result >= max_value means overflow - // For negative: result <= -max_value means overflow - if result >= max_value || result <= max_value.neg_wrapping() { - return Err(ArrowError::ComputeError(format!( - "Decimal overflow: rounded value exceeds precision {precision}" - ))); - } - - Ok(result) + quotient + .mul_checked(factor) + .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into())) } #[cfg(test)] From 356888431c1be218e6f3fb016b41bf618bf835c4 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 22 Jan 2026 22:56:08 +0530 Subject: [PATCH 4/5] scale based on the input --- datafusion/functions/src/math/round.rs | 284 +++++++++--------- .../sqllogictest/test_files/decimal.slt | 2 +- datafusion/sqllogictest/test_files/scalar.slt | 23 +- 3 files changed, 160 insertions(+), 149 deletions(-) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 0a4931c2461dd..3ca23cfd9a583 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -24,14 +24,10 @@ use arrow::datatypes::DataType::{ Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, }; use arrow::datatypes::{ - ArrowNativeTypeOp, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, - DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type, - Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, - MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION, - MAX_DECIMAL128_FOR_EACH_PRECISION, MAX_DECIMAL256_FOR_EACH_PRECISION, - MIN_DECIMAL32_FOR_EACH_PRECISION, MIN_DECIMAL64_FOR_EACH_PRECISION, - MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL256_FOR_EACH_PRECISION, + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, Float32Type, Float64Type, Int32Type, }; +use arrow::datatypes::{Field, FieldRef}; use arrow::error::ArrowError; use datafusion_common::types::{ NativeType, logical_float32, logical_float64, logical_int32, @@ -39,10 +35,11 @@ use datafusion_common::types::{ use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignature, TypeSignatureClass, Volatility, + Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; +use std::sync::Arc; #[user_doc( doc_section(label = "Math Functions"), @@ -122,31 +119,74 @@ impl ScalarUDFImpl for RoundFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match &arg_types[0] { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let input_field = &args.arg_fields[0]; + let input_type = input_field.data_type(); + + // Get decimal_places from scalar_arguments + // If dp is not a constant scalar, we must keep the original scale because + // we can't determine a single output scale for varying per-row dp values. + let (decimal_places, dp_is_scalar): (i32, bool) = + if args.scalar_arguments.len() > 1 { + match args.scalar_arguments[1] { + Some(ScalarValue::Int32(Some(v))) => (*v, true), + Some(ScalarValue::Int64(Some(v))) => (*v as i32, true), + _ => (0, false), // dp is a column or null - can't determine scale + } + } else { + (0, true) // No dp argument means default to 0 + }; + + // Calculate return type based on input type + // For decimals: reduce scale to decimal_places (reclaims precision for integer part) + // This matches Spark/DuckDB behavior where ROUND adjusts the scale + // BUT only if dp is a constant - otherwise keep original scale + let return_type = match input_type { Float32 => Float32, - // For decimal types, increase precision by 1 to accommodate potential - // carry-over from rounding (e.g., 999.9 -> 1000.0 requires an extra digit). - // This matches PostgreSQL behavior where ROUND can increase the number - // of digits before the decimal point. Decimal32(precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL32_MAX_PRECISION); - Decimal32(new_precision, *scale) + if dp_is_scalar { + let new_scale = (*scale).min(decimal_places.max(0) as i8); + Decimal32(*precision, new_scale) + } else { + Decimal32(*precision, *scale) + } } Decimal64(precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL64_MAX_PRECISION); - Decimal64(new_precision, *scale) + if dp_is_scalar { + let new_scale = (*scale).min(decimal_places.max(0) as i8); + Decimal64(*precision, new_scale) + } else { + Decimal64(*precision, *scale) + } } Decimal128(precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL128_MAX_PRECISION); - Decimal128(new_precision, *scale) + if dp_is_scalar { + let new_scale = (*scale).min(decimal_places.max(0) as i8); + Decimal128(*precision, new_scale) + } else { + Decimal128(*precision, *scale) + } } Decimal256(precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL256_MAX_PRECISION); - Decimal256(new_precision, *scale) + if dp_is_scalar { + let new_scale = (*scale).min(decimal_places.max(0) as i8); + Decimal256(*precision, new_scale) + } else { + Decimal256(*precision, *scale) + } } _ => Float64, - }) + }; + + Ok(Arc::new(Field::new( + self.name(), + return_type, + input_field.is_nullable(), + ))) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("use return_field_from_args instead") } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -162,7 +202,6 @@ impl ScalarUDFImpl for RoundFunc { &default_decimal_places }; - // Scalar fast path for float and decimal types - avoid array conversion overhead if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) = (&args.args[0], decimal_places) { @@ -190,38 +229,32 @@ impl ScalarUDFImpl for RoundFunc { Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) } ScalarValue::Decimal128(Some(v), precision, scale) => { - // Use increased precision from return_type to avoid overflow - let new_precision = (*precision + 1).min(DECIMAL128_MAX_PRECISION); - let rounded = round_decimal(*v, *scale, dp)?; - let validated = - validate_decimal128_precision(rounded, new_precision)?; + // Reduce scale to reclaim integer precision + let new_scale = (*scale).min(dp.max(0) as i8); + let rounded = round_decimal(*v, *scale, new_scale, dp)?; let scalar = - ScalarValue::Decimal128(Some(validated), new_precision, *scale); + ScalarValue::Decimal128(Some(rounded), *precision, new_scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal256(Some(v), precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL256_MAX_PRECISION); - let rounded = round_decimal(*v, *scale, dp)?; - let validated = - validate_decimal256_precision(rounded, new_precision)?; + let new_scale = (*scale).min(dp.max(0) as i8); + let rounded = round_decimal(*v, *scale, new_scale, dp)?; let scalar = - ScalarValue::Decimal256(Some(validated), new_precision, *scale); + ScalarValue::Decimal256(Some(rounded), *precision, new_scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal64(Some(v), precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL64_MAX_PRECISION); - let rounded = round_decimal(*v, *scale, dp)?; - let validated = validate_decimal64_precision(rounded, new_precision)?; + let new_scale = (*scale).min(dp.max(0) as i8); + let rounded = round_decimal(*v, *scale, new_scale, dp)?; let scalar = - ScalarValue::Decimal64(Some(validated), new_precision, *scale); + ScalarValue::Decimal64(Some(rounded), *precision, new_scale); Ok(ColumnarValue::Scalar(scalar)) } ScalarValue::Decimal32(Some(v), precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL32_MAX_PRECISION); - let rounded = round_decimal(*v, *scale, dp)?; - let validated = validate_decimal32_precision(rounded, new_precision)?; + let new_scale = (*scale).min(dp.max(0) as i8); + let rounded = round_decimal(*v, *scale, new_scale, dp)?; let scalar = - ScalarValue::Decimal32(Some(validated), new_precision, *scale); + ScalarValue::Decimal32(Some(rounded), *precision, new_scale); Ok(ColumnarValue::Scalar(scalar)) } _ => { @@ -232,7 +265,12 @@ impl ScalarUDFImpl for RoundFunc { } } } else { - round_columnar(&args.args[0], decimal_places, args.number_rows) + round_columnar( + &args.args[0], + decimal_places, + args.number_rows, + args.return_type(), + ) } } @@ -260,13 +298,14 @@ fn round_columnar( value: &ColumnarValue, decimal_places: &ColumnarValue, number_rows: usize, + return_type: &DataType, ) -> Result { let value_array = value.to_array(number_rows)?; let both_scalars = matches!(value, ColumnarValue::Scalar(_)) && matches!(decimal_places, ColumnarValue::Scalar(_)); - let arr: ArrayRef = match value_array.data_type() { - Float64 => { + let arr: ArrayRef = match (value_array.data_type(), return_type) { + (Float64, _) => { let result = calculate_binary_math::( value_array.as_ref(), decimal_places, @@ -274,7 +313,7 @@ fn round_columnar( )?; result as _ } - Float32 => { + (Float32, _) => { let result = calculate_binary_math::( value_array.as_ref(), decimal_places, @@ -282,9 +321,8 @@ fn round_columnar( )?; result as _ } - Decimal32(precision, scale) => { - // Use increased precision to avoid overflow from rounding carry-over - let new_precision = (*precision + 1).min(DECIMAL32_MAX_PRECISION); + (Decimal32(_, scale), Decimal32(precision, new_scale)) => { + // reduce scale to reclaim integer precision let result = calculate_binary_decimal_math::< Decimal32Type, Int32Type, @@ -293,17 +331,13 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| { - round_decimal(v, *scale, dp) - .and_then(|r| validate_decimal32_precision(r, new_precision)) - }, - new_precision, - *scale, + |v, dp| round_decimal(v, *scale, *new_scale, dp), + *precision, + *new_scale, )?; result as _ } - Decimal64(precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL64_MAX_PRECISION); + (Decimal64(_, scale), Decimal64(precision, new_scale)) => { let result = calculate_binary_decimal_math::< Decimal64Type, Int32Type, @@ -312,17 +346,13 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| { - round_decimal(v, *scale, dp) - .and_then(|r| validate_decimal64_precision(r, new_precision)) - }, - new_precision, - *scale, + |v, dp| round_decimal(v, *scale, *new_scale, dp), + *precision, + *new_scale, )?; result as _ } - Decimal128(precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL128_MAX_PRECISION); + (Decimal128(_, scale), Decimal128(precision, new_scale)) => { let result = calculate_binary_decimal_math::< Decimal128Type, Int32Type, @@ -331,17 +361,13 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| { - round_decimal(v, *scale, dp) - .and_then(|r| validate_decimal128_precision(r, new_precision)) - }, - new_precision, - *scale, + |v, dp| round_decimal(v, *scale, *new_scale, dp), + *precision, + *new_scale, )?; result as _ } - Decimal256(precision, scale) => { - let new_precision = (*precision + 1).min(DECIMAL256_MAX_PRECISION); + (Decimal256(_, scale), Decimal256(precision, new_scale)) => { let result = calculate_binary_decimal_math::< Decimal256Type, Int32Type, @@ -350,16 +376,13 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| { - round_decimal(v, *scale, dp) - .and_then(|r| validate_decimal256_precision(r, new_precision)) - }, - new_precision, - *scale, + |v, dp| round_decimal(v, *scale, *new_scale, dp), + *precision, + *new_scale, )?; result as _ } - other => exec_err!("Unsupported data type {other:?} for function round")?, + (other, _) => exec_err!("Unsupported data type {other:?} for function round")?, }; if both_scalars { @@ -381,61 +404,13 @@ where Ok((value * factor).round() / factor) } -/// Validate that an i32 (Decimal32) value fits within the specified precision. -/// Uses Arrow's pre-defined MAX/MIN_DECIMAL32_FOR_EACH_PRECISION constants. -fn validate_decimal32_precision(value: i32, precision: u8) -> Result { - let max = MAX_DECIMAL32_FOR_EACH_PRECISION[precision as usize]; - let min = MIN_DECIMAL32_FOR_EACH_PRECISION[precision as usize]; - if value > max || value < min { - return Err(ArrowError::ComputeError(format!( - "Decimal overflow: rounded value exceeds precision {precision}" - ))); - } - Ok(value) -} - -fn validate_decimal64_precision(value: i64, precision: u8) -> Result { - let max = MAX_DECIMAL64_FOR_EACH_PRECISION[precision as usize]; - let min = MIN_DECIMAL64_FOR_EACH_PRECISION[precision as usize]; - if value > max || value < min { - return Err(ArrowError::ComputeError(format!( - "Decimal overflow: rounded value exceeds precision {precision}" - ))); - } - Ok(value) -} - -fn validate_decimal128_precision(value: i128, precision: u8) -> Result { - let max = MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize]; - let min = MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize]; - if value > max || value < min { - return Err(ArrowError::ComputeError(format!( - "Decimal overflow: rounded value exceeds precision {precision}" - ))); - } - Ok(value) -} - -fn validate_decimal256_precision( - value: arrow::datatypes::i256, - precision: u8, -) -> Result { - let max = MAX_DECIMAL256_FOR_EACH_PRECISION[precision as usize]; - let min = MIN_DECIMAL256_FOR_EACH_PRECISION[precision as usize]; - if value > max || value < min { - return Err(ArrowError::ComputeError(format!( - "Decimal overflow: rounded value exceeds precision {precision}" - ))); - } - Ok(value) -} - fn round_decimal( value: V, - scale: i8, + input_scale: i8, + output_scale: i8, decimal_places: i32, ) -> Result { - let diff = i64::from(scale) - i64::from(decimal_places); + let diff = i64::from(input_scale) - i64::from(decimal_places); if diff <= 0 { return Ok(value); } @@ -456,7 +431,7 @@ fn round_decimal( let factor = ten.pow_checked(diff).map_err(|_| { ArrowError::ComputeError(format!( - "Overflow while rounding decimal with scale {scale} and decimal places {decimal_places}" + "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}" )) })?; @@ -475,9 +450,40 @@ fn round_decimal( })?; } - quotient - .mul_checked(factor) - .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into())) + // Determine how to scale the result based on output_scale vs computed scale + // computed_scale = max(0, min(input_scale, decimal_places)) + let computed_scale = if decimal_places >= 0 { + (input_scale as i32).min(decimal_places).max(0) as i8 + } else { + 0 + }; + + if output_scale == computed_scale { + // scale reduction, return quotient directly (or shifted for negative dp) + if decimal_places >= 0 { + Ok(quotient) + } else { + // For negative decimal_places, multiply by 10^(-decimal_places) to shift left + let neg_dp: u32 = (-decimal_places).try_into().map_err(|_| { + ArrowError::ComputeError(format!( + "Invalid negative decimal places: {decimal_places}" + )) + })?; + let shift_factor = ten.pow_checked(neg_dp).map_err(|_| { + ArrowError::ComputeError(format!( + "Overflow computing shift factor for decimal places {decimal_places}" + )) + })?; + quotient.mul_checked(shift_factor).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding decimal".into()) + }) + } + } else { + // Keep original scale behavior: multiply back by factor + quotient.mul_checked(factor).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding decimal".into()) + }) + } } #[cfg(test)] @@ -495,12 +501,14 @@ mod test { decimal_places: Option, ) -> Result { let number_rows = value.len(); + let return_type = value.data_type().clone(); let value = ColumnarValue::Array(value); let decimal_places = decimal_places .map(ColumnarValue::Array) .unwrap_or_else(|| ColumnarValue::Scalar(ScalarValue::Int32(Some(0)))); - let result = super::round_columnar(&value, &decimal_places, number_rows)?; + let result = + super::round_columnar(&value, &decimal_places, number_rows, &return_type)?; match result { ColumnarValue::Array(array) => Ok(array), ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1), diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index 2df76ac9add5d..eca2c88bb5f89 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -782,7 +782,7 @@ query TR select arrow_typeof(round(173975140545.855, 2)), round(173975140545.855, 2); ---- -Decimal128(16, 3) 173975140545.86 +Decimal128(15, 2) 173975140545.86 # smoke test for decimal parsing query RT diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 7c32087499a89..58b610867f874 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -931,11 +931,12 @@ query error Arrow error: Cast error: Can't cast value 2147483649 to type Int32 select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 2147483649); # round decimal should not cast to float +# scale reduces to match decimal_places query TR select arrow_typeof(round('173975140545.855'::decimal(38,10), 2)), round('173975140545.855'::decimal(38,10), 2); ---- -Decimal128(38, 10) 173975140545.86 +Decimal128(38, 2) 173975140545.86 # round decimal ties away from zero query RRRR @@ -951,29 +952,31 @@ query TR select arrow_typeof(round('12345.55'::decimal(10,2), -1)), round('12345.55'::decimal(10,2), -1); ---- -Decimal128(11, 2) 12350 +Decimal128(10, 0) 12350 # round decimal256 keeps decimals query TR select arrow_typeof(round('1234.5678'::decimal(50,4), 2)), round('1234.5678'::decimal(50,4), 2); ---- -Decimal256(51, 4) 1234.57 +Decimal256(50, 2) 1234.57 -# round decimal with carry-over that increases precision (issue fix) -# Previously this returned 100.0 due to precision overflow, now correctly returns 1000.0 +# round decimal with carry-over (reduce scale) +# Scale reduces from 1 to 0, allowing extra digit for carry-over query TRRR select arrow_typeof(round('999.9'::decimal(4,1))), round('999.9'::decimal(4,1)), round('-999.9'::decimal(4,1)), round('99.99'::decimal(4,2)); ---- -Decimal128(5, 1) 1000 -1000 100 +Decimal128(4, 0) 1000 -1000 100 -# round decimal at max precision returns error when result would overflow -# 37 nines + .9 rounded would need 38 digits, but precision is maxed at 38 -query error Decimal overflow: rounded value exceeds precision 38 -select round('9999999999999999999999999999999999999.9'::decimal(38,1)); +# round decimal at max precision now works (scale reduction handles overflow) +query TR +select arrow_typeof(round('9999999999999999999999999999999999999.9'::decimal(38,1))), + round('9999999999999999999999999999999999999.9'::decimal(38,1)); +---- +Decimal128(38, 0) 10000000000000000000000000000000000000 ## signum From d4bf1289a7a460bf4910090d34bc6136ba7b3a4f Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Fri, 23 Jan 2026 21:23:50 +0530 Subject: [PATCH 5/5] handle casting and return type assumptions --- datafusion/functions/src/math/round.rs | 370 +++++++++++++++--- datafusion/sqllogictest/test_files/scalar.slt | 22 +- 2 files changed, 339 insertions(+), 53 deletions(-) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 3ca23cfd9a583..493dc9899a606 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -24,8 +24,13 @@ use arrow::datatypes::DataType::{ Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, }; use arrow::datatypes::{ - ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, - Decimal256Type, Float32Type, Float64Type, Int32Type, + ArrowNativeTypeOp, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type, + Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, + MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION, + MAX_DECIMAL128_FOR_EACH_PRECISION, MAX_DECIMAL256_FOR_EACH_PRECISION, + MIN_DECIMAL32_FOR_EACH_PRECISION, MIN_DECIMAL64_FOR_EACH_PRECISION, + MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL256_FOR_EACH_PRECISION, i256, }; use arrow::datatypes::{Field, FieldRef}; use arrow::error::ArrowError; @@ -41,6 +46,91 @@ use datafusion_expr::{ use datafusion_macros::user_doc; use std::sync::Arc; +fn output_scale_for_decimal(input_scale: i8, decimal_places: i32) -> Result { + let new_scale = i32::from(input_scale).min(decimal_places.max(0)); + i8::try_from(new_scale).map_err(|_| { + datafusion_common::DataFusionError::Internal(format!( + "Computed decimal scale {new_scale} is out of range for i8" + )) + }) +} + +fn validate_decimal32_precision(value: i32, precision: u8) -> Result { + let max = MAX_DECIMAL32_FOR_EACH_PRECISION + .get(precision as usize) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Invalid decimal precision {precision}")) + })?; + let min = MIN_DECIMAL32_FOR_EACH_PRECISION + .get(precision as usize) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Invalid decimal precision {precision}")) + })?; + if value > *max || value < *min { + return Err(ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}" + ))); + } + Ok(value) +} + +fn validate_decimal64_precision(value: i64, precision: u8) -> Result { + let max = MAX_DECIMAL64_FOR_EACH_PRECISION + .get(precision as usize) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Invalid decimal precision {precision}")) + })?; + let min = MIN_DECIMAL64_FOR_EACH_PRECISION + .get(precision as usize) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Invalid decimal precision {precision}")) + })?; + if value > *max || value < *min { + return Err(ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}" + ))); + } + Ok(value) +} + +fn validate_decimal128_precision(value: i128, precision: u8) -> Result { + let max = MAX_DECIMAL128_FOR_EACH_PRECISION + .get(precision as usize) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Invalid decimal precision {precision}")) + })?; + let min = MIN_DECIMAL128_FOR_EACH_PRECISION + .get(precision as usize) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Invalid decimal precision {precision}")) + })?; + if value > *max || value < *min { + return Err(ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}" + ))); + } + Ok(value) +} + +fn validate_decimal256_precision(value: i256, precision: u8) -> Result { + let max = MAX_DECIMAL256_FOR_EACH_PRECISION + .get(precision as usize) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Invalid decimal precision {precision}")) + })?; + let min = MIN_DECIMAL256_FOR_EACH_PRECISION + .get(precision as usize) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Invalid decimal precision {precision}")) + })?; + if value > *max || value < *min { + return Err(ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}" + ))); + } + Ok(value) +} + #[user_doc( doc_section(label = "Math Functions"), description = "Rounds a number to the nearest integer.", @@ -126,63 +216,124 @@ impl ScalarUDFImpl for RoundFunc { // Get decimal_places from scalar_arguments // If dp is not a constant scalar, we must keep the original scale because // we can't determine a single output scale for varying per-row dp values. - let (decimal_places, dp_is_scalar): (i32, bool) = - if args.scalar_arguments.len() > 1 { - match args.scalar_arguments[1] { - Some(ScalarValue::Int32(Some(v))) => (*v, true), - Some(ScalarValue::Int64(Some(v))) => (*v as i32, true), - _ => (0, false), // dp is a column or null - can't determine scale - } - } else { - (0, true) // No dp argument means default to 0 - }; + let (decimal_places, dp_is_scalar) = match args.scalar_arguments.get(1) { + None => (0, true), // No dp argument means default to 0 + Some(None) => (0, false), // dp is a column + Some(Some(ScalarValue::Int32(Some(v)))) => (*v, true), + Some(Some(ScalarValue::Int64(Some(v)))) => { + let decimal_places = *v; + let v = i32::try_from(decimal_places).map_err(|_| { + datafusion_common::DataFusionError::Execution(format!( + "round decimal_places {decimal_places} is out of supported i32 range" + )) + })?; + (v, true) + } + Some(Some(scalar)) if scalar.is_null() => (0, true), // null dp => output is null + Some(Some(other)) => { + return exec_err!( + "Unexpected datatype for decimal_places: {}", + other.data_type() + ); + } + }; // Calculate return type based on input type // For decimals: reduce scale to decimal_places (reclaims precision for integer part) // This matches Spark/DuckDB behavior where ROUND adjusts the scale - // BUT only if dp is a constant - otherwise keep original scale + // BUT only if dp is a constant - otherwise keep original scale and add + // extra precision to accommodate potential carry-over. let return_type = match input_type { Float32 => Float32, Decimal32(precision, scale) => { if dp_is_scalar { - let new_scale = (*scale).min(decimal_places.max(0) as i8); - Decimal32(*precision, new_scale) + let new_scale = output_scale_for_decimal(*scale, decimal_places)?; + let new_precision = if *scale == 0 + && decimal_places < 0 + && decimal_places + .checked_neg() + .map(|abs| abs <= i32::from(*precision)) + .unwrap_or(false) + { + precision.saturating_add(1).min(DECIMAL32_MAX_PRECISION) + } else { + *precision + }; + Decimal32(new_precision, new_scale) } else { - Decimal32(*precision, *scale) + let new_precision = + precision.saturating_add(1).min(DECIMAL32_MAX_PRECISION); + Decimal32(new_precision, *scale) } } Decimal64(precision, scale) => { if dp_is_scalar { - let new_scale = (*scale).min(decimal_places.max(0) as i8); - Decimal64(*precision, new_scale) + let new_scale = output_scale_for_decimal(*scale, decimal_places)?; + let new_precision = if *scale == 0 + && decimal_places < 0 + && decimal_places + .checked_neg() + .map(|abs| abs <= i32::from(*precision)) + .unwrap_or(false) + { + precision.saturating_add(1).min(DECIMAL64_MAX_PRECISION) + } else { + *precision + }; + Decimal64(new_precision, new_scale) } else { - Decimal64(*precision, *scale) + let new_precision = + precision.saturating_add(1).min(DECIMAL64_MAX_PRECISION); + Decimal64(new_precision, *scale) } } Decimal128(precision, scale) => { if dp_is_scalar { - let new_scale = (*scale).min(decimal_places.max(0) as i8); - Decimal128(*precision, new_scale) + let new_scale = output_scale_for_decimal(*scale, decimal_places)?; + let new_precision = if *scale == 0 + && decimal_places < 0 + && decimal_places + .checked_neg() + .map(|abs| abs <= i32::from(*precision)) + .unwrap_or(false) + { + precision.saturating_add(1).min(DECIMAL128_MAX_PRECISION) + } else { + *precision + }; + Decimal128(new_precision, new_scale) } else { - Decimal128(*precision, *scale) + let new_precision = + precision.saturating_add(1).min(DECIMAL128_MAX_PRECISION); + Decimal128(new_precision, *scale) } } Decimal256(precision, scale) => { if dp_is_scalar { - let new_scale = (*scale).min(decimal_places.max(0) as i8); - Decimal256(*precision, new_scale) + let new_scale = output_scale_for_decimal(*scale, decimal_places)?; + let new_precision = if *scale == 0 + && decimal_places < 0 + && decimal_places + .checked_neg() + .map(|abs| abs <= i32::from(*precision)) + .unwrap_or(false) + { + precision.saturating_add(1).min(DECIMAL256_MAX_PRECISION) + } else { + *precision + }; + Decimal256(new_precision, new_scale) } else { - Decimal256(*precision, *scale) + let new_precision = + precision.saturating_add(1).min(DECIMAL256_MAX_PRECISION); + Decimal256(new_precision, *scale) } } _ => Float64, }; - Ok(Arc::new(Field::new( - self.name(), - return_type, - input_field.is_nullable(), - ))) + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), return_type, nullable))) } fn return_type(&self, _arg_types: &[DataType]) -> Result { @@ -212,6 +363,13 @@ impl ScalarUDFImpl for RoundFunc { let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar { *dp + } else if let ScalarValue::Int64(Some(dp)) = dp_scalar { + let decimal_places = *dp; + i32::try_from(decimal_places).map_err(|_| { + datafusion_common::DataFusionError::Execution(format!( + "round decimal_places {decimal_places} is out of supported i32 range" + )) + })? } else { return internal_err!( "Unexpected datatype for decimal_places: {}", @@ -228,33 +386,96 @@ impl ScalarUDFImpl for RoundFunc { let rounded = round_float(*v, dp)?; Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) } - ScalarValue::Decimal128(Some(v), precision, scale) => { - // Reduce scale to reclaim integer precision - let new_scale = (*scale).min(dp.max(0) as i8); - let rounded = round_decimal(*v, *scale, new_scale, dp)?; + ScalarValue::Decimal128(Some(v), _precision, scale) => { + let (out_precision, out_scale) = + if let Decimal128(p, s) = args.return_type() { + (*p, *s) + } else { + return internal_err!( + "Unexpected return type for decimal128 round: {}", + args.return_type() + ); + }; + let rounded = round_decimal(*v, *scale, out_scale, dp)?; + let rounded = if out_precision == DECIMAL128_MAX_PRECISION + && *scale == 0 + && dp < 0 + { + validate_decimal128_precision(rounded, out_precision) + } else { + Ok(rounded) + }?; let scalar = - ScalarValue::Decimal128(Some(rounded), *precision, new_scale); + ScalarValue::Decimal128(Some(rounded), out_precision, out_scale); Ok(ColumnarValue::Scalar(scalar)) } - ScalarValue::Decimal256(Some(v), precision, scale) => { - let new_scale = (*scale).min(dp.max(0) as i8); - let rounded = round_decimal(*v, *scale, new_scale, dp)?; + ScalarValue::Decimal256(Some(v), _precision, scale) => { + let (out_precision, out_scale) = + if let Decimal256(p, s) = args.return_type() { + (*p, *s) + } else { + return internal_err!( + "Unexpected return type for decimal256 round: {}", + args.return_type() + ); + }; + let rounded = round_decimal(*v, *scale, out_scale, dp)?; + let rounded = if out_precision == DECIMAL256_MAX_PRECISION + && *scale == 0 + && dp < 0 + { + validate_decimal256_precision(rounded, out_precision) + } else { + Ok(rounded) + }?; let scalar = - ScalarValue::Decimal256(Some(rounded), *precision, new_scale); + ScalarValue::Decimal256(Some(rounded), out_precision, out_scale); Ok(ColumnarValue::Scalar(scalar)) } - ScalarValue::Decimal64(Some(v), precision, scale) => { - let new_scale = (*scale).min(dp.max(0) as i8); - let rounded = round_decimal(*v, *scale, new_scale, dp)?; + ScalarValue::Decimal64(Some(v), _precision, scale) => { + let (out_precision, out_scale) = + if let Decimal64(p, s) = args.return_type() { + (*p, *s) + } else { + return internal_err!( + "Unexpected return type for decimal64 round: {}", + args.return_type() + ); + }; + let rounded = round_decimal(*v, *scale, out_scale, dp)?; + let rounded = if out_precision == DECIMAL64_MAX_PRECISION + && *scale == 0 + && dp < 0 + { + validate_decimal64_precision(rounded, out_precision) + } else { + Ok(rounded) + }?; let scalar = - ScalarValue::Decimal64(Some(rounded), *precision, new_scale); + ScalarValue::Decimal64(Some(rounded), out_precision, out_scale); Ok(ColumnarValue::Scalar(scalar)) } - ScalarValue::Decimal32(Some(v), precision, scale) => { - let new_scale = (*scale).min(dp.max(0) as i8); - let rounded = round_decimal(*v, *scale, new_scale, dp)?; + ScalarValue::Decimal32(Some(v), _precision, scale) => { + let (out_precision, out_scale) = + if let Decimal32(p, s) = args.return_type() { + (*p, *s) + } else { + return internal_err!( + "Unexpected return type for decimal32 round: {}", + args.return_type() + ); + }; + let rounded = round_decimal(*v, *scale, out_scale, dp)?; + let rounded = if out_precision == DECIMAL32_MAX_PRECISION + && *scale == 0 + && dp < 0 + { + validate_decimal32_precision(rounded, out_precision) + } else { + Ok(rounded) + }?; let scalar = - ScalarValue::Decimal32(Some(rounded), *precision, new_scale); + ScalarValue::Decimal32(Some(rounded), out_precision, out_scale); Ok(ColumnarValue::Scalar(scalar)) } _ => { @@ -303,6 +524,7 @@ fn round_columnar( let value_array = value.to_array(number_rows)?; let both_scalars = matches!(value, ColumnarValue::Scalar(_)) && matches!(decimal_places, ColumnarValue::Scalar(_)); + let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_)); let arr: ArrayRef = match (value_array.data_type(), return_type) { (Float64, _) => { @@ -331,7 +553,16 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, *new_scale, dp), + |v, dp| { + let rounded = round_decimal(v, *scale, *new_scale, dp)?; + if *precision == DECIMAL32_MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + validate_decimal32_precision(rounded, *precision) + } else { + Ok(rounded) + } + }, *precision, *new_scale, )?; @@ -346,7 +577,16 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, *new_scale, dp), + |v, dp| { + let rounded = round_decimal(v, *scale, *new_scale, dp)?; + if *precision == DECIMAL64_MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + validate_decimal64_precision(rounded, *precision) + } else { + Ok(rounded) + } + }, *precision, *new_scale, )?; @@ -361,7 +601,16 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, *new_scale, dp), + |v, dp| { + let rounded = round_decimal(v, *scale, *new_scale, dp)?; + if *precision == DECIMAL128_MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + validate_decimal128_precision(rounded, *precision) + } else { + Ok(rounded) + } + }, *precision, *new_scale, )?; @@ -376,7 +625,16 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, *new_scale, dp), + |v, dp| { + let rounded = round_decimal(v, *scale, *new_scale, dp)?; + if *precision == DECIMAL256_MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + validate_decimal256_precision(rounded, *precision) + } else { + Ok(rounded) + } + }, *precision, *new_scale, )?; @@ -453,7 +711,12 @@ fn round_decimal( // Determine how to scale the result based on output_scale vs computed scale // computed_scale = max(0, min(input_scale, decimal_places)) let computed_scale = if decimal_places >= 0 { - (input_scale as i32).min(decimal_places).max(0) as i8 + let new_scale = i32::from(input_scale).min(decimal_places).max(0); + i8::try_from(new_scale).map_err(|_| { + ArrowError::ComputeError(format!( + "Computed decimal scale {new_scale} is out of range for i8" + )) + })? } else { 0 }; @@ -501,6 +764,9 @@ mod test { decimal_places: Option, ) -> Result { let number_rows = value.len(); + // NOTE: For decimal inputs, the actual ROUND return type can differ from the + // input type (scale reduction for literal `decimal_places`). These unit tests + // only exercise Float32/Float64 behavior. let return_type = value.data_type().clone(); let value = ColumnarValue::Array(value); let decimal_places = decimal_places diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 58b610867f874..c933d9c29e6d0 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -923,7 +923,7 @@ select round(a), round(b), round(c) from small_floats; # round with too large # max Int32 is 2147483647 -query error Arrow error: Cast error: Can't cast value 2147483648 to type Int32 +query error round decimal_places 2147483648 is out of supported i32 range select round(3.14, 2147483648); # with array @@ -954,6 +954,13 @@ select arrow_typeof(round('12345.55'::decimal(10,2), -1)), ---- Decimal128(10, 0) 12350 +# round decimal scale 0 negative places (carry can require extra precision) +query TR +select arrow_typeof(round('99'::decimal(2,0), -1)), + round('99'::decimal(2,0), -1); +---- +Decimal128(3, 0) 100 + # round decimal256 keeps decimals query TR select arrow_typeof(round('1234.5678'::decimal(50,4), 2)), @@ -971,6 +978,14 @@ select arrow_typeof(round('999.9'::decimal(4,1))), ---- Decimal128(4, 0) 1000 -1000 100 +# round decimal with carry-over and non-literal decimal_places (increase precision) +# Scale can't be reduced when decimal_places isn't a constant, so precision increases. +query TR +select arrow_typeof(round(val, dp)), round(val, dp) +from (values (cast('999.9' as decimal(4,1)), 0)) as t(val, dp); +---- +Decimal128(5, 1) 1000 + # round decimal at max precision now works (scale reduction handles overflow) query TR select arrow_typeof(round('9999999999999999999999999999999999999.9'::decimal(38,1))), @@ -978,6 +993,11 @@ select arrow_typeof(round('9999999999999999999999999999999999999.9'::decimal(38, ---- Decimal128(38, 0) 10000000000000000000000000000000000000 +# round decimal at max precision with non-literal decimal_places can overflow +query error Decimal overflow: rounded value exceeds precision 38 +select round(val, dp) +from (values (cast('9999999999999999999999999999999999999.9' as decimal(38,1)), 0)) as t(val, dp); + ## signum # signum scalar function