diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index d1906a4bf0e01..c438351ab7bd8 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -21,7 +21,7 @@ use std::any::Any; use super::power::PowerFunc; -use crate::utils::calculate_binary_math; +use crate::utils::calculate_binary_math_numeric; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{ DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, @@ -272,56 +272,84 @@ impl ScalarUDFImpl for LogFunc { }; let value = value.to_array(args.number_rows)?; + // Get return type + let arg_types: Vec = + args.args.iter().map(|arg| arg.data_type()).collect(); + let out_type: &DataType = &self.return_type(&arg_types)?; + + // Safety: all `dec_scale.expect` calls below are infallible since the left argument + // is decimal array as per `calculate_binary_math` contract. let output: ArrayRef = match value.data_type() { DataType::Float16 => { - calculate_binary_math::( + calculate_binary_math_numeric::( &value, &base, - |value, base| Ok(value.log(base)), + |value, base, _| Ok(value.log(base)), + out_type, )? } DataType::Float32 => { - calculate_binary_math::( + calculate_binary_math_numeric::( &value, &base, - |value, base| Ok(value.log(base)), + |value, base, _| Ok(value.log(base)), + out_type, )? } DataType::Float64 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } - DataType::Decimal32(_, scale) => { - calculate_binary_math::( + calculate_binary_math_numeric::( &value, &base, - |value, base| log_decimal32(value, *scale, base), + |value, base, _| Ok(value.log(base)), + out_type, )? 
} - DataType::Decimal64(_, scale) => { - calculate_binary_math::( + DataType::Decimal32(_, _) => { + calculate_binary_math_numeric::( &value, &base, - |value, base| log_decimal64(value, *scale, base), + |value, base, dec_scale| { + log_decimal32(value, dec_scale.expect("value is decimal").1, base) + }, + out_type, )? } - DataType::Decimal128(_, scale) => { - calculate_binary_math::( + DataType::Decimal64(_, _) => { + calculate_binary_math_numeric::( &value, &base, - |value, base| log_decimal128(value, *scale, base), - )? - } - DataType::Decimal256(_, scale) => { - calculate_binary_math::( - &value, - &base, - |value, base| log_decimal256(value, *scale, base), + |value, base, dec_scale| { + log_decimal64(value, dec_scale.expect("value is decimal").1, base) + }, + out_type, )? } + DataType::Decimal128(_, _) => calculate_binary_math_numeric::< + Decimal128Type, + Float64Type, + Float64Type, + _, + >( + &value, + &base, + |value, base, dec_scale| { + log_decimal128(value, dec_scale.expect("value is decimal").1, base) + }, + out_type, + )?, + DataType::Decimal256(_, _) => calculate_binary_math_numeric::< + Decimal256Type, + Float64Type, + Float64Type, + _, + >( + &value, + &base, + |value, base, dec_scale| { + log_decimal256(value, dec_scale.expect("value is decimal").1, base) + }, + out_type, + )?, other => { return exec_err!("Unsupported data type {other:?} for function log"); } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 489c59aa3d6fa..5cca7d5452a1d 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -20,7 +20,9 @@ use std::any::Any; use super::log::LogFunc; -use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; +use crate::utils::{ + calculate_binary_math, calculate_binary_math_decimal, calculate_binary_math_numeric, +}; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::i256; use arrow::datatypes::{ @@ -437,25 +439,30 @@ impl 
ScalarUDFImpl for PowerFunc { return pow_decimal_with_float_fallback(&base, exponent, args.number_rows); } + let out_type = base.data_type(); + // Safety: all `dec_scale.expect` calls below are infallible since the left argument + // is decimal array as per `calculate_binary_math` contract. let arr: ArrayRef = match (base.data_type(), exponent.data_type()) { (DataType::Float64, DataType::Float64) => { - calculate_binary_math::( + calculate_binary_math_numeric::( &base, exponent, - |b, e| Ok(f64::powf(b, e)), + |b, e, _| Ok(f64::powf(b, e)), + out_type, )? } - (DataType::Decimal32(precision, scale), DataType::Int64) => { - calculate_binary_decimal_math::( + (DataType::Decimal32(_, _), DataType::Int64) => { + calculate_binary_math_decimal::( &base, exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, + |b, e, dec_scale| { + pow_decimal_int(b, dec_scale.expect("left is decimal").1, e) + }, + out_type, )? } - (DataType::Decimal32(precision, scale), DataType::Float64) => { - calculate_binary_decimal_math::< + (DataType::Decimal32(_, _), DataType::Float64) => { + calculate_binary_math_decimal::< Decimal32Type, Float64Type, Decimal32Type, @@ -463,22 +470,24 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal_float(b, *scale, e), - *precision, - *scale, + |b, e, dec_scale| { + pow_decimal_float(b, dec_scale.expect("left is decimal").1, e) + }, + out_type, )? } - (DataType::Decimal64(precision, scale), DataType::Int64) => { - calculate_binary_decimal_math::( + (DataType::Decimal64(_, _), DataType::Int64) => { + calculate_binary_math_decimal::( &base, exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, + |b, e, dec_scale| { + pow_decimal_int(b, dec_scale.expect("left is decimal").1, e) + }, + out_type, )? 
} - (DataType::Decimal64(precision, scale), DataType::Float64) => { - calculate_binary_decimal_math::< + (DataType::Decimal64(_, _), DataType::Float64) => { + calculate_binary_math_decimal::< Decimal64Type, Float64Type, Decimal64Type, @@ -486,13 +495,14 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal_float(b, *scale, e), - *precision, - *scale, + |b, e, dec_scale| { + pow_decimal_float(b, dec_scale.expect("left is decimal").1, e) + }, + out_type, )? } - (DataType::Decimal128(precision, scale), DataType::Int64) => { - calculate_binary_decimal_math::< + (DataType::Decimal128(_, _), DataType::Int64) => { + calculate_binary_math_decimal::< Decimal128Type, Int64Type, Decimal128Type, @@ -500,13 +510,14 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, + |b, e, dec_scale| { + pow_decimal_int(b, dec_scale.expect("left is decimal").1, e) + }, + out_type, )? } - (DataType::Decimal128(precision, scale), DataType::Float64) => { - calculate_binary_decimal_math::< + (DataType::Decimal128(_, _), DataType::Float64) => { + calculate_binary_math_decimal::< Decimal128Type, Float64Type, Decimal128Type, @@ -514,13 +525,14 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal_float(b, *scale, e), - *precision, - *scale, + |b, e, dec_scale| { + pow_decimal_float(b, dec_scale.expect("left is decimal").1, e) + }, + out_type, )? } - (DataType::Decimal256(precision, scale), DataType::Int64) => { - calculate_binary_decimal_math::< + (DataType::Decimal256(_, _), DataType::Int64) => { + calculate_binary_math_decimal::< Decimal256Type, Int64Type, Decimal256Type, @@ -528,13 +540,14 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal256_int(b, *scale, e), - *precision, - *scale, + |b, e, dec_scale| { + pow_decimal256_int(b, dec_scale.expect("left is decimal").1, e) + }, + out_type, )? 
} - (DataType::Decimal256(precision, scale), DataType::Float64) => { - calculate_binary_decimal_math::< + (DataType::Decimal256(_, _), DataType::Float64) => { + calculate_binary_math_decimal::< Decimal256Type, Float64Type, Decimal256Type, @@ -542,9 +555,10 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal256_float(b, *scale, e), - *precision, - *scale, + |b, e, dec_scale| { + pow_decimal256_float(b, dec_scale.expect("left is decimal").1, e) + }, + out_type, )? } (base_type, exp_type) => { diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index de70788128b88..5009c2e942000 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -17,9 +17,9 @@ use std::any::Any; -use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; +use crate::utils::{calculate_binary_math_decimal, calculate_binary_math_numeric}; -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef}; use arrow::datatypes::DataType::{ Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, }; @@ -173,25 +173,33 @@ fn round_columnar( let both_scalars = matches!(value, ColumnarValue::Scalar(_)) && matches!(decimal_places, ColumnarValue::Scalar(_)); + let out_type = value_array.data_type(); + + // Safety: all `dec_scale.expect` calls below are infallible since the left argument + // is decimal array as per `calculate_binary_math` contract. 
let arr: ArrayRef = match value_array.data_type() { Float64 => { - let result = calculate_binary_math::( - value_array.as_ref(), - decimal_places, - round_float::, - )?; + let result = + calculate_binary_math_numeric::( + value_array.as_ref(), + decimal_places, + |l, r, _| round_float::(l, r), + out_type, + )?; result as _ } Float32 => { - let result = calculate_binary_math::( - value_array.as_ref(), - decimal_places, - round_float::, - )?; + let result = + calculate_binary_math_numeric::( + value_array.as_ref(), + decimal_places, + |l, r, _| round_float::(l, r), + out_type, + )?; result as _ } - Decimal32(precision, scale) => { - let result = calculate_binary_decimal_math::< + Decimal32(_, _) => { + let result = calculate_binary_math_decimal::< Decimal32Type, Int32Type, Decimal32Type, @@ -199,14 +207,15 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), - *precision, - *scale, + |v, dp, dec_scale| { + round_decimal(v, dec_scale.expect("value is decimal").1, dp) + }, + out_type, )?; result as _ } - Decimal64(precision, scale) => { - let result = calculate_binary_decimal_math::< + Decimal64(_, _) => { + let result = calculate_binary_math_decimal::< Decimal64Type, Int32Type, Decimal64Type, @@ -214,14 +223,15 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), - *precision, - *scale, + |v, dp, dec_scale| { + round_decimal(v, dec_scale.expect("value is decimal").1, dp) + }, + out_type, )?; result as _ } - Decimal128(precision, scale) => { - let result = calculate_binary_decimal_math::< + Decimal128(_, _) => { + let result = calculate_binary_math_decimal::< Decimal128Type, Int32Type, Decimal128Type, @@ -229,14 +239,15 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), - *precision, - *scale, + |v, dp, dec_scale| { + round_decimal(v, dec_scale.expect("value is decimal").1, dp) + }, + out_type, )?; result as _ } - 
Decimal256(precision, scale) => { - let result = calculate_binary_decimal_math::< + Decimal256(_, _) => { + let result = calculate_binary_math_decimal::< Decimal256Type, Int32Type, Decimal256Type, @@ -244,9 +255,10 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), - *precision, - *scale, + |v, dp, dec_scale| { + round_decimal(v, dec_scale.expect("value is decimal").1, dp) + }, + out_type, )?; result as _ } diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index e4980728b18a0..2fcb851c5fe58 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -15,11 +15,16 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; -use arrow::compute::try_binary; -use arrow::datatypes::{DataType, DecimalType}; +use arrow::array::{ + Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, PrimitiveArray, + PrimitiveBuilder, +}; +use arrow::compute::{DecimalCast, try_binary}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, +}; use arrow::error::ArrowError; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_datafusion_err}; use datafusion_expr::ColumnarValue; use datafusion_expr::function::Hint; use std::sync::Arc; @@ -128,6 +133,8 @@ where /// - `R`: Right array primitive type /// - `O`: Output array primitive type /// - `F`: Functor computing `fun(l: L, r: R) -> Result` +/// +/// Deprecated. 
use calculate_binary_math_numeric instead pub fn calculate_binary_math( left: &dyn Array, right: &ColumnarValue, @@ -167,6 +174,394 @@ where Ok(Arc::new(result) as _) } +/// Helper to extract a native value from a ScalarValue, providing a DataFusionError +fn try_from_scalar(value: ScalarValue) -> Result +where + T: ArrowPrimitiveType, + T::Native: TryFrom, +{ + // Construct an error string beforehand to avoid extra cloning + let err_str = format!( + "Cannot convert scalar value {} of type {} to {}", + value, + value.data_type(), + T::DATA_TYPE + ); + T::Native::try_from(value).map_err(|_| DataFusionError::Execution(err_str)) +} + +/// Extract precision and scale from a decimal DataType, or None if not decimal +fn get_decimal_precision_scale(data_type: &DataType) -> Option<(u8, i8)> { + match data_type { + DataType::Decimal32(precision, scale) => Some((*precision, *scale)), + DataType::Decimal64(precision, scale) => Some((*precision, *scale)), + DataType::Decimal128(precision, scale) => Some((*precision, *scale)), + DataType::Decimal256(precision, scale) => Some((*precision, *scale)), + _ => None, + } +} + +/// Adapter for `arrow::compute::rescale_decimal` to rescale an array of values +fn rescale_array( + input: &PrimitiveArray, + output_precision: u8, + output_scale: i8, +) -> Result> +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + let input_precision = input.precision(); + let input_scale = input.scale(); + let mut builder = PrimitiveBuilder::::with_capacity(input.len()); + for i in 0..input.len() { + if input.is_null(i) { + builder.append_null(); + } else { + let value = input.value(i); + // Change scale of one value using arrow casting function + match arrow::compute::rescale_decimal::( + value, + input_precision, + input_scale, + output_precision, + output_scale, + ) { + Some(rescaled_value) => builder.append_value(rescaled_value), + None => builder.append_null(), + } + } + } + let result: PrimitiveArray = builder + .finish() + 
.with_precision_and_scale(output_precision, output_scale)?; + Ok(result) +} + +/// Rescales an array to the given precision and scale if it is a decimal +/// Returns an execution error otherwise +fn rescale_decimal_array( + input: &dyn Array, + precision: u8, + scale: i8, +) -> Result { + match input.data_type() { + DataType::Decimal128(_, _) => Ok(Arc::new(rescale_array::( + input.as_primitive::(), + precision, + scale, + )?)), + DataType::Decimal256(_, _) => Ok(Arc::new(rescale_array::( + input.as_primitive::(), + precision, + scale, + )?)), + DataType::Decimal32(_, _) => Ok(Arc::new(rescale_array::( + input.as_primitive::(), + precision, + scale, + )?)), + DataType::Decimal64(_, _) => Ok(Arc::new(rescale_array::( + input.as_primitive::(), + precision, + scale, + )?)), + _ => Err(exec_datafusion_err!( + "Failed to rescale value of non-decimal type {}", + input.data_type() + )), + } +} + +/// Rescales a scalar value to the given precision and scale if it is a decimal +/// Returns an execution error otherwise +fn rescale_decimal_scalar( + input: &ScalarValue, + precision: u8, + scale: i8, +) -> Result { + match input { + ScalarValue::Decimal128(_, _, _) => { + input.cast_to(&DataType::Decimal128(precision, scale)) + } + ScalarValue::Decimal256(_, _, _) => { + input.cast_to(&DataType::Decimal256(precision, scale)) + } + ScalarValue::Decimal32(_, _, _) => { + input.cast_to(&DataType::Decimal32(precision, scale)) + } + ScalarValue::Decimal64(_, _, _) => { + input.cast_to(&DataType::Decimal64(precision, scale)) + } + _ => Err(exec_datafusion_err!( + "Failed to rescale value of non-decimal type {}", + input.data_type() + )), + } +} + +/// Cast an array to the given output data type using `arrow::compute::cast` +fn cast_array_to( + input: Arc>, + out_type: &DataType, +) -> Result>> +where + O: ArrowPrimitiveType, +{ + if input.data_type() == out_type { + // Return as is + Ok(input) + } else { + // cast to output data type, performing rescaling + let casted_result = 
arrow::compute::cast(input.as_ref(), out_type)?; + casted_result + .as_primitive_opt::() + .ok_or_else(|| { + exec_datafusion_err!("Failed to cast array to type {}", O::DATA_TYPE) + }) + .map(|arr| Arc::new(arr.clone())) + } +} + +// Shorthand type for decimal precision and scale information +pub type DecScale = Option<(u8, i8)>; + +/// A helper function for internal use. +/// Computes a binary math function for input arrays using a specified function. +/// +/// Handles left `L` and right `R` Arrow types and performs decimal rescaling on them, if needed. +/// +/// The functor also receives a `DecScale`: the precision and scale of the decimal operand (of the one +/// with the smaller scale when both are decimal), or `None` when neither operand is a decimal. +fn calculate_binary_math_impl( + left: &dyn Array, + right: &ColumnarValue, + fun: F, +) -> Result<(Arc>, DecScale)> +where + L: ArrowPrimitiveType, + R: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(L::Native, R::Native, DecScale) -> Result, + R::Native: TryFrom, +{ + log::debug!( + "calculate_binary_math_impl called with left {left:?} and right {right:?}, types {} x {} -> {}", + L::DATA_TYPE, + R::DATA_TYPE, + O::DATA_TYPE + ); + + let left = left.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("Failed to cast left to type {}", L::DATA_TYPE) + })?; + let right = right.cast_to(&R::DATA_TYPE, None)?; + + let result = match right { + ColumnarValue::Array(right) => { + let right = right.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("Failed to cast right to type {}", R::DATA_TYPE) + })?; + + // Four possible combinations of decimal and non-decimal inputs + match ( + get_decimal_precision_scale(left.data_type()), + get_decimal_precision_scale(right.data_type()), + ) { + ( + Some((left_precision, left_scale)), + Some((right_precision, right_scale)), + ) => { + log::debug!( + "calculate_binary_math: rescaling {left_precision}, {left_scale} and {right_precision}, {right_scale}" + ); + + // Scale both arguments to a common scale (choose the smaller to 
avoid overflows) + if left_scale < right_scale { + let right_scaled = + rescale_decimal_array(right, left_precision, left_scale)?; + let right_scaled = + right_scaled.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!( + "Failed to cast right array to type {}", + R::DATA_TYPE + ) + })?; + log::debug!( + "calculate_binary_math: rescaled array right {right_scaled:?}" + ); + let interim = + try_binary::<_, _, _, O>(left, right_scaled, |l, r| { + fun(l, r, Some((left_precision, left_scale))) + })?; + (Arc::new(interim) as _, Some((left_precision, left_scale))) + } else { + let left_scaled = + rescale_decimal_array(left, right_precision, right_scale)?; + let left_scaled = + left_scaled.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!( + "Failed to cast left array to type {}", + L::DATA_TYPE + ) + })?; + log::debug!( + "calculate_binary_math: rescaled array left {left_scaled:?}" + ); + let interim = + try_binary::<_, _, _, O>(left_scaled, right, |l, r| { + fun(l, r, Some((right_precision, right_scale))) + })?; + (Arc::new(interim) as _, Some((right_precision, right_scale))) + } + } + (Some(left_dec_scale), None) => { + let interim = try_binary::<_, _, _, O>(left, right, |l, r| { + fun(l, r, Some(left_dec_scale)) + })?; + (Arc::new(interim) as _, Some(left_dec_scale)) + } + // Two last patterns together, when left is not decimal + (None, opt_right_precision_and_scale) => { + let interim = try_binary::<_, _, _, O>(left, right, |l, r| { + fun(l, r, opt_right_precision_and_scale) + })?; + (Arc::new(interim) as _, opt_right_precision_and_scale) + } + } + } + ColumnarValue::Scalar(scalar) if scalar.is_null() => { + // Null scalar is castable to any numeric, creating a non-null expression. 
+ // Provide null array explicitly to make result null + let interim = PrimitiveArray::::new_null(left.len()); + (Arc::new(interim) as _, None) + } + ColumnarValue::Scalar(right) => { + // Four possible combinations of decimal and non-decimal inputs + match ( + get_decimal_precision_scale(left.data_type()), + get_decimal_precision_scale(&right.data_type()), + ) { + ( + Some((left_precision, left_scale)), + Some((right_precision, right_scale)), + ) => { + if left_scale < right_scale { + let right_scaled = + rescale_decimal_scalar(&right, left_precision, left_scale)?; + let right_native = try_from_scalar::(right_scaled)?; + log::debug!( + "calculate_binary_math: rescaled scalar right {right_native:?}" + ); + let interim = left.try_unary::<_, O, _>(|l| { + fun(l, right_native, Some((left_precision, left_scale))) + })?; + (Arc::new(interim) as _, Some((left_precision, left_scale))) + } else { + let left_scaled = + rescale_decimal_array(left, right_precision, right_scale)?; + let left_scaled = + left_scaled.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!( + "Failed to cast left array to type {}", + L::DATA_TYPE + ) + })?; + log::debug!( + "calculate_binary_math: rescaled array left {left_scaled:?}" + ); + let right_native = try_from_scalar::(right.clone())?; + let interim = left_scaled.try_unary::<_, O, _>(|l| { + fun(l, right_native, Some((right_precision, right_scale))) + })?; + (Arc::new(interim) as _, Some((right_precision, right_scale))) + } + } + (Some((left_precision, left_scale)), None) => { + let right_native = try_from_scalar::(right.clone())?; + let interim = left.try_unary::<_, O, _>(|l| { + fun(l, right_native, Some((left_precision, left_scale))) + })?; + (Arc::new(interim) as _, Some((left_precision, left_scale))) + } + // Two last patterns together, when left is not decimal + (None, opt_right_precision_and_scale) => { + let right_native = try_from_scalar::(right.clone())?; + let interim = left.try_unary::<_, O, _>(|l| { + fun(l, right_native, 
opt_right_precision_and_scale) + })?; + (Arc::new(interim) as _, opt_right_precision_and_scale) + } + } + } + }; + log::debug!("calculate_binary_math: result {result:?}"); + Ok(result) +} + +/// Computes a binary math function for input arrays using a specified function +/// with any left `L` and right `R` Arrow types, and result of a decimal type `O`. +/// +/// Functor `F` computes one operation for Arrow types. +/// It also receives a `DecScale`: the precision and scale of the decimal operand (of the one +/// with the smaller scale when both are decimal), or `None` when neither operand is a decimal. +/// +pub fn calculate_binary_math_decimal( + left: &dyn Array, + right: &ColumnarValue, + fun: F, + out_type: &DataType, +) -> Result>> +where + L: ArrowPrimitiveType, + R: ArrowPrimitiveType, + O: ArrowPrimitiveType + DecimalType, + F: Fn(L::Native, R::Native, DecScale) -> Result, + R::Native: TryFrom, +{ + let (interim, scale_opt) = + calculate_binary_math_impl::(left, right, fun)?; + + // Perform rescaling having `interim` as a decimal array + let result: Arc> = + if let Some((out_precision, out_scale)) = scale_opt { + // Apply scale and cast + let interim = Arc::unwrap_or_clone(interim) + .with_precision_and_scale(out_precision, out_scale)?; + cast_array_to(Arc::new(interim), out_type)? + } else { + // Just cast + cast_array_to(interim, out_type)? + }; + log::debug!("calculate_binary_math_decimal: result {result:?} out_type={out_type:?}"); + Ok(result) +} + +/// Computes a binary math function for input arrays using a specified function +/// with any left `L` and right `R` Arrow types, and result of a non-decimal type `O`. +/// +/// Functor `F` computes one operation for Arrow types. 
+/// +pub fn calculate_binary_math_numeric( + left: &dyn Array, + right: &ColumnarValue, + fun: F, + out_type: &DataType, +) -> Result>> +where + L: ArrowPrimitiveType, + R: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(L::Native, R::Native, DecScale) -> Result, + R::Native: TryFrom, +{ + let (interim, _) = calculate_binary_math_impl::(left, right, fun)?; + // Ignore provided decimal scale and precision, just cast to output type + let result = cast_array_to(interim, out_type)?; + log::debug!("calculate_binary_math_numeric: result {result:?}"); + Ok(result) +} + /// Computes a binary math function for input arrays using a specified function /// and apply rescaling to given precision and scale. /// Generic types: @@ -174,6 +569,8 @@ where /// - `R`: Right array primitive type /// - `O`: Output array decimal type /// - `F`: Functor computing `fun(l: L, r: R) -> Result` +/// +/// Deprecated. use calculate_binary_math_decimal instead pub fn calculate_binary_decimal_math( left: &dyn Array, right: &ColumnarValue, @@ -363,12 +760,21 @@ pub mod test { }; } - use arrow::datatypes::DataType; use itertools::Either; pub(crate) use test_function; use super::*; + use arrow::array::{Decimal128Array, Float64Array, Int64Array, PrimitiveArray}; + use arrow::datatypes::{ + DECIMAL128_MAX_PRECISION, DataType, Decimal128Type, Float64Type, Int64Type, + }; + #[cfg(test)] + #[ctor::ctor] + fn init() { + // Enable RUST_LOG logging configuration for test + let _ = env_logger::try_init(); + } #[test] fn string_to_int_type() { let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap(); @@ -509,4 +915,350 @@ pub mod test { } } } + + // Test constant + const LARGE: i128 = 2i128.pow(70); + + #[test] + fn test_calculate_binary_math_array_scalar_int64() { + let left = Int64Array::from(vec![0, 12, i64::MAX - 42]); + let right = ColumnarValue::Scalar(ScalarValue::Int64(Some(42))); + let result = calculate_binary_math::( + &left, + &right, + |x, y| Ok(x + y), + ) + .expect("calculate"); + 
assert_eq!(result.len(), 3); + assert_eq!(result.value(0), 42); + assert_eq!(result.value(1), 54); + assert_eq!(result.value(2), i64::MAX); + } + + #[test] + fn test_calculate_binary_math_array_array_int64() { + let left = Int64Array::from(vec![0, 12, i64::MAX - 42]); + let right = ColumnarValue::Array(Arc::new(Int64Array::from(vec![42, 42, 42]))); + let result = calculate_binary_math::( + &left, + &right, + |x, y| Ok(x + y), + ) + .expect("calculate"); + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), 42); + assert_eq!(result.value(1), 54); + assert_eq!(result.value(2), i64::MAX); + } + + #[test] + fn test_calculate_binary_decimal128_array_decimal_scalar_unscaled() { + for precision in [10, 20, DECIMAL128_MAX_PRECISION] { + let left = Decimal128Array::from(vec![0, 12, LARGE]) + .with_precision_and_scale(precision, 0) + .unwrap(); + let right = + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(42), precision, 0)); + let result = + calculate_binary_math_decimal::< + Decimal128Type, + Decimal128Type, + Decimal128Type, + _, + >(&left, &right, |x, y, _| Ok(x + y), left.data_type()) + .expect("calculate"); + _check_calculate_binary_result(result, left.data_type().clone()); + } + } + + #[test] + fn test_calculate_binary_decimal128_array_decimal_array_unscaled() { + for precision in [10, 20, DECIMAL128_MAX_PRECISION] { + let left = Decimal128Array::from(vec![0, 12, LARGE]) + .with_precision_and_scale(precision, 0) + .unwrap(); + let right = ColumnarValue::Array(Arc::new( + Decimal128Array::from(vec![Some(42), Some(42), Some(42)]) + .with_precision_and_scale(precision, 0) + .unwrap(), + )); + let result = + calculate_binary_math_decimal::< + Decimal128Type, + Decimal128Type, + Decimal128Type, + _, + >(&left, &right, |x, y, _| Ok(x + y), left.data_type()) + .expect("calculate"); + _check_calculate_binary_result(result, left.data_type().clone()); + } + } + + #[test] + fn test_calculate_binary_decimal128_array_decimal_array_same_scale() { + for precision in [10, 
20, DECIMAL128_MAX_PRECISION] { + // 0, 12, 2**70 + let left = Decimal128Array::from(vec![0, 12000, LARGE * 1000]) + .with_precision_and_scale(precision, 3) + .unwrap(); + // 42 + let right = ColumnarValue::Array(Arc::new( + Decimal128Array::from(vec![Some(42000), Some(42000), Some(42000)]) + .with_precision_and_scale(precision, 3) + .unwrap(), + )); + let result = + calculate_binary_math_decimal::< + Decimal128Type, + Decimal128Type, + Decimal128Type, + _, + >(&left, &right, |x, y, _| Ok(x + y), left.data_type()) + .expect("calculate"); + _check_calculate_binary_result(result, left.data_type().clone()); + } + } + + #[test] + fn test_calculate_binary_decimal128_array_decimal_array_different_scale() { + for precision in [10, 20, DECIMAL128_MAX_PRECISION] { + // 0, 12, 2**70 + let left = Decimal128Array::from(vec![0, 12000, LARGE * 1000]) + .with_precision_and_scale(precision, 3) + .unwrap(); + // 42 + let right = ColumnarValue::Array(Arc::new( + Decimal128Array::from(vec![Some(4200000), Some(4200000), Some(4200000)]) + .with_precision_and_scale(precision, 5) + .unwrap(), + )); + let result = + calculate_binary_math_decimal::< + Decimal128Type, + Decimal128Type, + Decimal128Type, + _, + >(&left, &right, |x, y, _| Ok(x + y), left.data_type()) + .expect("calculate"); + _check_calculate_binary_result(result, left.data_type().clone()); + } + } + + #[test] + fn test_calculate_binary_decimal128_array_decimal_literal_different_scale() { + for precision in [10, 20, DECIMAL128_MAX_PRECISION] { + // 0, 12, 2**70 + let left = Decimal128Array::from(vec![0, 12000, LARGE * 1000]) + .with_precision_and_scale(precision, 3) + .unwrap(); + // 42 + let right = ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(4200000), + precision, + 5, + )); + + let result = calculate_binary_math_decimal::< + Decimal128Type, + Decimal128Type, + Decimal128Type, + _, + >( + &left, + &right, + |x, y, dec_scale| { + let op_precision = dec_scale.unwrap().0; + let op_scale = dec_scale.unwrap().1; + 
assert_eq!(op_precision, precision); + assert_eq!(op_scale, 3); + Ok(x + y) + }, + left.data_type(), + ) + .expect("calculate"); + _check_calculate_binary_result(result, left.data_type().clone()); + } + } + + #[test] + fn test_calculate_binary_decimal128_array_float_array() { + for precision in [10, 20, DECIMAL128_MAX_PRECISION] { + // 0, 12, 2**70 + let left = Decimal128Array::from(vec![0, 12000, LARGE * 1000]) + .with_precision_and_scale(precision, 3) + .unwrap(); + // 42 + let right = ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 42.0, 42.0, 42.0, + ]))); + let result = calculate_binary_math_decimal::< + Decimal128Type, + Float64Type, + Decimal128Type, + _, + >( + &left, + &right, + |x, y, dec_scale| { + let scale = dec_scale.unwrap().1; + // To produce a resulting value, one should scale right numeric to match decimal scale + // 0 + 42 + // 12000 + 42 + // 1180591620717411303424000 + 42 + Ok(x + i128::from(y.round() as i64) + * i128::from(10).pow(scale as u32)) + }, + left.data_type(), + ) + .expect("calculate"); + _check_calculate_binary_result(result, left.data_type().clone()); + } + } + + #[test] + fn test_calculate_binary_decimal128_array_float_literal() { + for precision in [10, 20, DECIMAL128_MAX_PRECISION] { + // 0, 12, 2**70 + let left = Decimal128Array::from(vec![0, 12000, LARGE * 1000]) + .with_precision_and_scale(precision, 3) + .unwrap(); + // 42 + let right = ColumnarValue::Scalar(ScalarValue::Float64(Some(42.0))); + let result = calculate_binary_math_decimal::< + Decimal128Type, + Float64Type, + Decimal128Type, + _, + >( + &left, + &right, + |x, y, dec_scale| { + let scale = dec_scale.unwrap().1; + // To produce a resulting value, one should scale right numeric to match decimal scale + // 0 + 42 + // 12000 + 42 + // 1180591620717411303424000 + 42 + Ok(x + i128::from(y.round() as i64) + * i128::from(10).pow(scale as u32)) + }, + left.data_type(), + ) + .expect("calculate"); + _check_calculate_binary_result(result, 
left.data_type().clone()); + } + } + + #[test] + fn test_calculate_binary_decimal128_array_float_array_float_result() { + for precision in [10, 20, DECIMAL128_MAX_PRECISION] { + // 0, 12, 2**70 + let left = Decimal128Array::from(vec![0, 12000, LARGE * 1000]) + .with_precision_and_scale(precision, 3) + .unwrap(); + // 42 + let right = ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 42.0, 42.0, 42.0, + ]))); + let result = calculate_binary_math_numeric::< + Decimal128Type, + Float64Type, + Float64Type, + _, + >( + &left, + &right, + |x, y, dec_scale| { + let scale = dec_scale.unwrap().1; + Ok((1 + x / 10i128.pow(scale as u32)).ilog2() as f64 + y) + }, + &DataType::Float64, + ) + .expect("calculate"); + assert_eq!(*result.data_type(), DataType::Float64); + assert_eq!(result.len(), 3); + assert!((result.value(0) - 42.0).abs() < f64::EPSILON); + assert!((result.value(1) - 3.0 - 42.0).abs() < f64::EPSILON); + assert!((result.value(2) - 70.0 - 42.0).abs() < f64::EPSILON); + } + } + + #[test] + fn test_calculate_binary_decimal128_array_float_array_float_literal_result() { + for precision in [10, 20, DECIMAL128_MAX_PRECISION] { + // 0, 12, 2**70 + let left = Decimal128Array::from(vec![0, 12000, LARGE * 1000]) + .with_precision_and_scale(precision, 3) + .unwrap(); + // 42 + let right = ColumnarValue::Scalar(ScalarValue::Float64(Some(42.0))); + let result = calculate_binary_math_numeric::< + Decimal128Type, + Float64Type, + Float64Type, + _, + >( + &left, + &right, + |x, y, dec_scale| { + // a random calculation to capture scale usage in the test + let scale = dec_scale.unwrap().1; + Ok((1 + x / 10i128.pow(scale as u32)).ilog2() as f64 + y) + }, + &DataType::Float64, + ) + .expect("calculate"); + + assert_eq!(*result.data_type(), DataType::Float64); + assert_eq!(result.len(), 3); + assert!((result.value(0) - 42.0).abs() < f64::EPSILON); + assert!((result.value(1) - 3.0 - 42.0).abs() < f64::EPSILON); + assert!((result.value(2) - 70.0 - 42.0).abs() < f64::EPSILON); + } 
+ } + + // Test helper to verify against a known result + fn _check_calculate_binary_result( + result: Arc>, + expected_type: DataType, + ) { + log::debug!( + "checking result: {:?} of type: {}", + result, + result.data_type() + ); + let (_precision, scale) = + get_decimal_precision_scale(&expected_type).expect("decimal type"); + assert_eq!(*result.data_type(), expected_type); + let ten_scaled = i128::from(10).pow(scale as u32); + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), 42 * ten_scaled); + assert_eq!(result.value(1), (42 + 12) * ten_scaled); + assert_eq!(result.value(2), (i128::from(42) + LARGE) * ten_scaled); + } + + #[test] + fn test_rescale_array_down() { + let input = Decimal128Array::from(vec![0, 1200000, 4200000]) + .with_precision_and_scale(20, 5) + .unwrap(); + let result = rescale_array(&input, 20, 0); + assert!(result.is_ok()); + let expected = Decimal128Array::from(vec![0, 12, 42]) + .with_precision_and_scale(20, 0) + .unwrap(); + assert_eq!(result.unwrap(), expected); + } + + #[test] + fn test_rescale_array_up() { + let input = Decimal128Array::from(vec![0, 12, 42]) + .with_precision_and_scale(20, 0) + .unwrap(); + let result = rescale_array(&input, 20, 5); + assert!(result.is_ok()); + let expected = Decimal128Array::from(vec![0, 1200000, 4200000]) + .with_precision_and_scale(20, 5) + .unwrap(); + assert_eq!(result.unwrap(), expected); + } } diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index f53f4939299c5..a00d8be8c7cf3 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -1165,6 +1165,29 @@ SELECT power(2.0, null) ---- NULL +# Different scales for arguments +query R +SELECT power(100::decimal(38, 5), 3::decimal(38, 2)) +---- +1000000 + +query R +SELECT power(100::decimal(72, 10), 3::decimal(38, 2)) +---- +1000000 + +query R +SELECT power(100::decimal(38, 5), 3) +---- +1000000 + +query R +SELECT 
power(arrow_cast(100, 'Decimal128(38, 5)'), 3) +---- +1000000 + + + # Array variants of power function query RR rowsort SELECT distinct c1*100000, power(c1*100000, 2) from decimal_simple;