Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions datafusion/functions/benches/signum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow::{
util::bench_util::create_primitive_array,
};
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_functions::math::signum;
Expand Down Expand Up @@ -88,6 +89,51 @@ fn criterion_benchmark(c: &mut Criterion) {
)
})
});

// Scalar-input benchmarks: exercise the `ColumnarValue::Scalar` path of `signum`
let scalar_f32_args =
vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(-42.5)))];
let scalar_f32_arg_fields =
vec![Field::new("a", DataType::Float32, false).into()];
let return_field_f32 = Field::new("f", DataType::Float32, false).into();

c.bench_function(&format!("signum f32 scalar: {size}"), |b| {
b.iter(|| {
black_box(
signum
.invoke_with_args(ScalarFunctionArgs {
args: scalar_f32_args.clone(),
arg_fields: scalar_f32_arg_fields.clone(),
number_rows: 1,
return_field: Arc::clone(&return_field_f32),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

let scalar_f64_args =
vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(-42.5)))];
let scalar_f64_arg_fields =
vec![Field::new("a", DataType::Float64, false).into()];
let return_field_f64 = Field::new("f", DataType::Float64, false).into();

c.bench_function(&format!("signum f64 scalar: {size}"), |b| {
b.iter(|| {
black_box(
signum
.invoke_with_args(ScalarFunctionArgs {
args: scalar_f64_args.clone(),
arg_fields: scalar_f64_arg_fields.clone(),
number_rows: 1,
return_field: Arc::clone(&return_field_f64),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});
}
}

Expand Down
82 changes: 50 additions & 32 deletions datafusion/functions/src/math/signum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,19 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray};
use arrow::array::AsArray;
use arrow::datatypes::DataType::{Float32, Float64};
use arrow::datatypes::{DataType, Float32Type, Float64Type};

use datafusion_common::{Result, exec_err};
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue, internal_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_macros::user_doc;

use crate::utils::make_scalar_function;

#[user_doc(
doc_section(label = "Math Functions"),
description = r#"Returns the sign of a number.
Expand Down Expand Up @@ -98,41 +97,60 @@ impl ScalarUDFImpl for SignumFunc {
}

/// Evaluates `signum` on a single scalar or array argument.
///
/// Inputs that compare equal to zero (this includes `-0.0`) produce `0.0`;
/// every other value is passed to the float type's own `signum`. A null
/// scalar yields a null cast to the declared return type. Scalars are
/// computed directly without being expanded into an array.
///
/// # Errors
///
/// Propagates any error from `take_function_args` (presumably an
/// argument-count mismatch — confirm against that helper), and returns an
/// internal error when the argument is neither `Float32` nor `Float64`.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    // Read the return type before `args.args` is moved out below.
    let return_type = args.return_type().clone();
    let [arg] = take_function_args(self.name(), args.args)?;

    match arg {
        ColumnarValue::Scalar(scalar) => {
            // Null in, null out — typed as the function's return type.
            if scalar.is_null() {
                return ColumnarValue::Scalar(ScalarValue::Null)
                    .cast_to(&return_type, None);
            }

            match scalar {
                ScalarValue::Float64(Some(v)) => {
                    // `-0.0 == 0.0` holds, so both zeros map to positive zero.
                    let result = if v == 0.0 { 0.0 } else { v.signum() };
                    Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result))))
                }
                ScalarValue::Float32(Some(v)) => {
                    let result = if v == 0.0 { 0.0 } else { v.signum() };
                    Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result))))
                }
                _ => {
                    internal_err!(
                        "Unexpected scalar type for signum: {:?}",
                        scalar.data_type()
                    )
                }
            }
        }
        ColumnarValue::Array(array) => match array.data_type() {
            Float64 => Ok(ColumnarValue::Array(Arc::new(
                array.as_primitive::<Float64Type>().unary::<_, Float64Type>(
                    |x: f64| {
                        if x == 0.0 { 0.0 } else { x.signum() }
                    },
                ),
            ))),
            Float32 => Ok(ColumnarValue::Array(Arc::new(
                array.as_primitive::<Float32Type>().unary::<_, Float32Type>(
                    |x: f32| {
                        if x == 0.0 { 0.0 } else { x.signum() }
                    },
                ),
            ))),
            other => {
                internal_err!("Unsupported data type {other:?} for function signum")
            }
        },
    }
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

/// Array kernel for the `signum` SQL function.
///
/// Only the first element of `args` is read. It must be a `Float32` or
/// `Float64` array; each value that compares equal to zero becomes `0.0`,
/// and every other value is passed through the float type's `signum`.
///
/// # Errors
///
/// Returns an execution error when the first array has any other data type.
///
/// # Panics
///
/// Panics if `args` is empty, because the first element is indexed directly.
fn signum(args: &[ArrayRef]) -> Result<ArrayRef> {
    let input = &args[0];

    match input.data_type() {
        Float32 => {
            let signs = input
                .as_primitive::<Float32Type>()
                .unary::<_, Float32Type>(|v: f32| if v == 0_f32 { 0_f32 } else { v.signum() });
            Ok(Arc::new(signs) as ArrayRef)
        }
        Float64 => {
            let signs = input
                .as_primitive::<Float64Type>()
                .unary::<_, Float64Type>(|v: f64| if v == 0_f64 { 0_f64 } else { v.signum() });
            Ok(Arc::new(signs) as ArrayRef)
        }
        other => exec_err!("Unsupported data type {other:?} for function signum"),
    }
}

#[cfg(test)]
mod test {
use std::sync::Arc;
Expand Down