diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index e4980728b18a0..b9bde1454994c 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -147,7 +147,7 @@ where if scalar.is_null() { // Null scalar is castable to any numeric, creating a non-null expression. // Provide null array explicitly to make result null - PrimitiveArray::::new_null(1) + PrimitiveArray::::new_null(left.len()) } else { let right = R::Native::try_from(scalar.clone()).map_err(|_| { DataFusionError::NotImplemented(format!( @@ -363,12 +363,30 @@ pub mod test { }; } - use arrow::datatypes::DataType; + use arrow::{ + array::Int32Array, + datatypes::{DataType, Int32Type}, + }; use itertools::Either; pub(crate) use test_function; use super::*; + #[test] + fn test_calculate_binary_math_scalar_null() { + let left = Int32Array::from(vec![1, 2]); + let right = ColumnarValue::Scalar(ScalarValue::Int32(None)); + let result = calculate_binary_math::( + &left, + &right, + |x, y| Ok(x + y), + ) + .unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result.null_count(), 2); + } + #[test] fn string_to_int_type() { let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap();