Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 148 additions & 42 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,19 @@ use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
Decimal256Type, Float32Type, Float64Type, Int32Type,
};
use arrow::datatypes::{Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::types::{
NativeType, logical_float32, logical_float64, logical_int32,
};
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignature, TypeSignatureClass, Volatility,
Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;

#[user_doc(
doc_section(label = "Math Functions"),
Expand Down Expand Up @@ -117,15 +119,74 @@ impl ScalarUDFImpl for RoundFunc {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(match arg_types[0].clone() {
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let input_field = &args.arg_fields[0];
let input_type = input_field.data_type();

// Get decimal_places from scalar_arguments
// If dp is not a constant scalar, we must keep the original scale because
// we can't determine a single output scale for varying per-row dp values.
let (decimal_places, dp_is_scalar): (i32, bool) =
if args.scalar_arguments.len() > 1 {
match args.scalar_arguments[1] {
Some(ScalarValue::Int32(Some(v))) => (*v, true),
Some(ScalarValue::Int64(Some(v))) => (*v as i32, true),
_ => (0, false), // dp is a column or null - can't determine scale
}
} else {
(0, true) // No dp argument means default to 0
};

// Calculate return type based on input type
// For decimals: reduce scale to decimal_places (reclaims precision for integer part)
// This matches Spark/DuckDB behavior where ROUND adjusts the scale
// BUT only if dp is a constant - otherwise keep original scale
let return_type = match input_type {
Float32 => Float32,
dt @ Decimal128(_, _)
| dt @ Decimal256(_, _)
| dt @ Decimal32(_, _)
| dt @ Decimal64(_, _) => dt,
Decimal32(precision, scale) => {
if dp_is_scalar {
let new_scale = (*scale).min(decimal_places.max(0) as i8);
Decimal32(*precision, new_scale)
} else {
Decimal32(*precision, *scale)
}
}
Decimal64(precision, scale) => {
if dp_is_scalar {
let new_scale = (*scale).min(decimal_places.max(0) as i8);
Decimal64(*precision, new_scale)
} else {
Decimal64(*precision, *scale)
}
}
Decimal128(precision, scale) => {
if dp_is_scalar {
let new_scale = (*scale).min(decimal_places.max(0) as i8);
Decimal128(*precision, new_scale)
} else {
Decimal128(*precision, *scale)
}
}
Decimal256(precision, scale) => {
if dp_is_scalar {
let new_scale = (*scale).min(decimal_places.max(0) as i8);
Decimal256(*precision, new_scale)
} else {
Decimal256(*precision, *scale)
}
}
_ => Float64,
})
};

Ok(Arc::new(Field::new(
self.name(),
return_type,
input_field.is_nullable(),
)))
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("use return_field_from_args instead")
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand All @@ -141,7 +202,6 @@ impl ScalarUDFImpl for RoundFunc {
&default_decimal_places
};

// Scalar fast path for float and decimal types - avoid array conversion overhead
if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
(&args.args[0], decimal_places)
{
Expand Down Expand Up @@ -169,27 +229,32 @@ impl ScalarUDFImpl for RoundFunc {
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
}
ScalarValue::Decimal128(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
// Reduce scale to reclaim integer precision
let new_scale = (*scale).min(dp.max(0) as i8);
let rounded = round_decimal(*v, *scale, new_scale, dp)?;
let scalar =
ScalarValue::Decimal128(Some(rounded), *precision, *scale);
ScalarValue::Decimal128(Some(rounded), *precision, new_scale);
Ok(ColumnarValue::Scalar(scalar))
}
ScalarValue::Decimal256(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let new_scale = (*scale).min(dp.max(0) as i8);
let rounded = round_decimal(*v, *scale, new_scale, dp)?;
let scalar =
ScalarValue::Decimal256(Some(rounded), *precision, *scale);
ScalarValue::Decimal256(Some(rounded), *precision, new_scale);
Ok(ColumnarValue::Scalar(scalar))
}
ScalarValue::Decimal64(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let new_scale = (*scale).min(dp.max(0) as i8);
let rounded = round_decimal(*v, *scale, new_scale, dp)?;
let scalar =
ScalarValue::Decimal64(Some(rounded), *precision, *scale);
ScalarValue::Decimal64(Some(rounded), *precision, new_scale);
Ok(ColumnarValue::Scalar(scalar))
}
ScalarValue::Decimal32(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let new_scale = (*scale).min(dp.max(0) as i8);
let rounded = round_decimal(*v, *scale, new_scale, dp)?;
let scalar =
ScalarValue::Decimal32(Some(rounded), *precision, *scale);
ScalarValue::Decimal32(Some(rounded), *precision, new_scale);
Ok(ColumnarValue::Scalar(scalar))
}
_ => {
Expand All @@ -200,7 +265,12 @@ impl ScalarUDFImpl for RoundFunc {
}
}
} else {
round_columnar(&args.args[0], decimal_places, args.number_rows)
round_columnar(
&args.args[0],
decimal_places,
args.number_rows,
args.return_type(),
)
}
}

Expand Down Expand Up @@ -228,29 +298,31 @@ fn round_columnar(
value: &ColumnarValue,
decimal_places: &ColumnarValue,
number_rows: usize,
return_type: &DataType,
) -> Result<ColumnarValue> {
let value_array = value.to_array(number_rows)?;
let both_scalars = matches!(value, ColumnarValue::Scalar(_))
&& matches!(decimal_places, ColumnarValue::Scalar(_));

let arr: ArrayRef = match value_array.data_type() {
Float64 => {
let arr: ArrayRef = match (value_array.data_type(), return_type) {
(Float64, _) => {
let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
value_array.as_ref(),
decimal_places,
round_float::<f64>,
)?;
result as _
}
Float32 => {
(Float32, _) => {
let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
value_array.as_ref(),
decimal_places,
round_float::<f32>,
)?;
result as _
}
Decimal32(precision, scale) => {
(Decimal32(_, scale), Decimal32(precision, new_scale)) => {
// reduce scale to reclaim integer precision
let result = calculate_binary_decimal_math::<
Decimal32Type,
Int32Type,
Expand All @@ -259,13 +331,13 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
|v, dp| round_decimal(v, *scale, dp),
|v, dp| round_decimal(v, *scale, *new_scale, dp),
*precision,
*scale,
*new_scale,
)?;
result as _
}
Decimal64(precision, scale) => {
(Decimal64(_, scale), Decimal64(precision, new_scale)) => {
let result = calculate_binary_decimal_math::<
Decimal64Type,
Int32Type,
Expand All @@ -274,13 +346,13 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
|v, dp| round_decimal(v, *scale, dp),
|v, dp| round_decimal(v, *scale, *new_scale, dp),
*precision,
*scale,
*new_scale,
)?;
result as _
}
Decimal128(precision, scale) => {
(Decimal128(_, scale), Decimal128(precision, new_scale)) => {
let result = calculate_binary_decimal_math::<
Decimal128Type,
Int32Type,
Expand All @@ -289,13 +361,13 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
|v, dp| round_decimal(v, *scale, dp),
|v, dp| round_decimal(v, *scale, *new_scale, dp),
*precision,
*scale,
*new_scale,
)?;
result as _
}
Decimal256(precision, scale) => {
(Decimal256(_, scale), Decimal256(precision, new_scale)) => {
let result = calculate_binary_decimal_math::<
Decimal256Type,
Int32Type,
Expand All @@ -304,13 +376,13 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
|v, dp| round_decimal(v, *scale, dp),
|v, dp| round_decimal(v, *scale, *new_scale, dp),
*precision,
*scale,
*new_scale,
)?;
result as _
}
other => exec_err!("Unsupported data type {other:?} for function round")?,
(other, _) => exec_err!("Unsupported data type {other:?} for function round")?,
};

if both_scalars {
Expand All @@ -334,10 +406,11 @@ where

fn round_decimal<V: ArrowNativeTypeOp>(
value: V,
scale: i8,
input_scale: i8,
output_scale: i8,
decimal_places: i32,
) -> Result<V, ArrowError> {
let diff = i64::from(scale) - i64::from(decimal_places);
let diff = i64::from(input_scale) - i64::from(decimal_places);
if diff <= 0 {
return Ok(value);
}
Expand All @@ -358,7 +431,7 @@ fn round_decimal<V: ArrowNativeTypeOp>(

let factor = ten.pow_checked(diff).map_err(|_| {
ArrowError::ComputeError(format!(
"Overflow while rounding decimal with scale {scale} and decimal places {decimal_places}"
"Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
))
})?;

Expand All @@ -377,9 +450,40 @@ fn round_decimal<V: ArrowNativeTypeOp>(
})?;
}

quotient
.mul_checked(factor)
.map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
// Determine how to scale the result based on output_scale vs computed scale
// computed_scale = max(0, min(input_scale, decimal_places))
let computed_scale = if decimal_places >= 0 {
(input_scale as i32).min(decimal_places).max(0) as i8
} else {
0
};

if output_scale == computed_scale {
// scale reduction, return quotient directly (or shifted for negative dp)
if decimal_places >= 0 {
Ok(quotient)
} else {
// For negative decimal_places, multiply by 10^(-decimal_places) to shift left
let neg_dp: u32 = (-decimal_places).try_into().map_err(|_| {
ArrowError::ComputeError(format!(
"Invalid negative decimal places: {decimal_places}"
))
})?;
let shift_factor = ten.pow_checked(neg_dp).map_err(|_| {
ArrowError::ComputeError(format!(
"Overflow computing shift factor for decimal places {decimal_places}"
))
})?;
quotient.mul_checked(shift_factor).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding decimal".into())
})
}
} else {
// Keep original scale behavior: multiply back by factor
quotient.mul_checked(factor).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding decimal".into())
})
}
}

#[cfg(test)]
Expand All @@ -397,12 +501,14 @@ mod test {
decimal_places: Option<ArrayRef>,
) -> Result<ArrayRef, DataFusionError> {
let number_rows = value.len();
let return_type = value.data_type().clone();
let value = ColumnarValue::Array(value);
let decimal_places = decimal_places
.map(ColumnarValue::Array)
.unwrap_or_else(|| ColumnarValue::Scalar(ScalarValue::Int32(Some(0))));

let result = super::round_columnar(&value, &decimal_places, number_rows)?;
let result =
super::round_columnar(&value, &decimal_places, number_rows, &return_type)?;
match result {
ColumnarValue::Array(array) => Ok(array),
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ query TR
select arrow_typeof(round(173975140545.855, 2)),
round(173975140545.855, 2);
----
Decimal128(15, 3) 173975140545.86
Decimal128(15, 2) 173975140545.86

# smoke test for decimal parsing
query RT
Expand Down
Loading