diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 08a197a60eb75..b34e52d7f2e19 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -23,6 +23,7 @@ use arrow::{ util::bench_util::create_primitive_array, }; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::signum; @@ -88,6 +89,51 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + // Scalar benchmarks (the optimization we added) + let scalar_f32_args = + vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(-42.5)))]; + let scalar_f32_arg_fields = + vec![Field::new("a", DataType::Float32, false).into()]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function(&format!("signum f32 scalar: {size}"), |b| { + b.iter(|| { + black_box( + signum + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = + vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(-42.5)))]; + let scalar_f64_arg_fields = + vec![Field::new("a", DataType::Float64, false).into()]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function(&format!("signum f64 scalar: {size}"), |b| { + b.iter(|| { + black_box( + signum + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } } diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index e217088c64c2e..8a3769a12f294 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -18,11 +18,12 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray}; +use arrow::array::AsArray; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -30,8 +31,6 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -use crate::utils::make_scalar_function; - #[user_doc( doc_section(label = "Math Functions"), description = r#"Returns the sign of a number. @@ -98,7 +97,53 @@ impl ScalarUDFImpl for SignumFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(signum, vec![])(&args.args) + let return_type = args.return_type().clone(); + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(&return_type, None); + } + + match scalar { + ScalarValue::Float64(Some(v)) => { + let result = if v == 0.0 { 0.0 } else { v.signum() }; + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) + } + ScalarValue::Float32(Some(v)) => { + let result = if v == 0.0 { 0.0 } else { v.signum() }; + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result)))) + } + _ => { + internal_err!( + "Unexpected scalar type for signum: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Float64 => Ok(ColumnarValue::Array(Arc::new( + array.as_primitive::().unary::<_, Float64Type>( + |x: f64| { + if x == 0.0 { 0.0 } else { x.signum() } + }, + ), + ))), + Float32 => Ok(ColumnarValue::Array(Arc::new( + array.as_primitive::().unary::<_, Float32Type>( + |x: f32| { + if x == 0.0 { 0.0 } else { x.signum() } + }, + ), + ))), + other => { + internal_err!("Unsupported data type {other:?} for function signum") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { @@ -106,33 +151,6 @@ impl ScalarUDFImpl for SignumFunc { } } -/// signum SQL function -fn signum(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float64Type>( - |x: f64| { - if x == 0_f64 { 0_f64 } else { x.signum() } - }, - ), - ) as ArrayRef), - - Float32 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float32Type>( - |x: f32| { - if x == 0_f32 { 0_f32 } else { x.signum() } - }, - ), - ) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function signum"), - } -} - #[cfg(test)] mod test { use std::sync::Arc;