diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 4fb78933d7a5c..978e9f627565c 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -77,6 +77,7 @@ pub mod statistics {
     pub use datafusion_expr_common::statistics::*;
 }
 mod predicate_bounds;
+pub mod preimage;
 pub mod ptr_eq;
 pub mod test;
 pub mod tree_node;
diff --git a/datafusion/expr/src/preimage.rs b/datafusion/expr/src/preimage.rs
new file mode 100644
index 0000000000000..67ca7a91bbf38
--- /dev/null
+++ b/datafusion/expr/src/preimage.rs
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_expr_common::interval_arithmetic::Interval;
+
+use crate::Expr;
+
+/// Return from [`crate::ScalarUDFImpl::preimage`]
+pub enum PreimageResult {
+    /// No preimage exists for the specified value
+    None,
+    /// The expression always evaluates to the specified constant
+    /// given that `expr` is within the interval
+    Range { expr: Expr, interval: Box<Interval> },
+}
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 0654370ac7ebf..870e318a62c3d 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -19,6 +19,7 @@
 
 use crate::async_udf::AsyncScalarUDF;
 use crate::expr::schema_name_from_exprs_comma_separated_without_space;
+use crate::preimage::PreimageResult;
 use crate::simplify::{ExprSimplifyResult, SimplifyContext};
 use crate::sort_properties::{ExprProperties, SortProperties};
 use crate::udf_eq::UdfEq;
@@ -232,6 +233,18 @@ impl ScalarUDF {
         self.inner.is_nullable(args, schema)
     }
 
+    /// Return a preimage
+    ///
+    /// See [`ScalarUDFImpl::preimage`] for more details.
+    pub fn preimage(
+        &self,
+        args: &[Expr],
+        lit_expr: &Expr,
+        info: &SimplifyContext,
+    ) -> Result<PreimageResult> {
+        self.inner.preimage(args, lit_expr, info)
+    }
+
     /// Invoke the function on `args`, returning the appropriate result.
     ///
     /// See [`ScalarUDFImpl::invoke_with_args`] for details.
@@ -696,6 +709,32 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
         Ok(ExprSimplifyResult::Original(args))
     }
 
+    /// Returns the [preimage] for this function and the specified scalar value, if any.
+    ///
+    /// A preimage is a single contiguous [`Interval`] of values where the function
+    /// will always return the value of `lit_expr`.
+    ///
+    /// Implementations should return intervals with an inclusive lower bound and
+    /// exclusive upper bound.
+    ///
+    /// This rewrite is described in the [ClickHouse Paper] and is particularly
+    /// useful for simplifying expressions such as `date_part` or equivalent
+    /// functions. For example, given `date_part(YEAR, k) = 2024`, the [preimage]
+    /// of `2024` under `date_part(YEAR, k)` is the range of dates covering the
+    /// entire year of 2024, so the expression can be rewritten to
+    /// `k >= '2024-01-01' AND k < '2025-01-01'`, which is often more optimizable.
+    ///
+    /// [ClickHouse Paper]: https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf
+    /// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image
+    fn preimage(
+        &self,
+        _args: &[Expr],
+        _lit_expr: &Expr,
+        _info: &SimplifyContext,
+    ) -> Result<PreimageResult> {
+        Ok(PreimageResult::None)
+    }
+
     /// Returns true if some of this `exprs` subexpressions may not be evaluated
     /// and thus any side effects (like divide by zero) may not be encountered.
     ///
@@ -926,6 +965,15 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
         self.inner.simplify(args, info)
     }
 
+    fn preimage(
+        &self,
+        args: &[Expr],
+        lit_expr: &Expr,
+        info: &SimplifyContext,
+    ) -> Result<PreimageResult> {
+        self.inner.preimage(args, lit_expr, info)
+    }
+
     fn conditional_arguments<'a>(
         &self,
         args: &'a [Expr],
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index b9ef69dd08ff6..7bbb7e79d18d6 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -39,7 +39,7 @@ use datafusion_common::{
 };
 use datafusion_expr::{
     BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, and,
-    binary::BinaryTypeCoercer, lit, or,
+    binary::BinaryTypeCoercer, lit, or, preimage::PreimageResult,
 };
 use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult};
 use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
@@ -51,7 +51,6 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP
 
 use super::inlist_simplifier::ShortenInListSimplifier;
 use super::utils::*;
-use crate::analyzer::type_coercion::TypeCoercionRewriter;
 use crate::simplify_expressions::SimplifyContext;
 use crate::simplify_expressions::regex::simplify_regex_expr;
 use crate::simplify_expressions::unwrap_cast::{
@@ -59,6 +58,10 @@ use crate::simplify_expressions::unwrap_cast::{
     is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
     unwrap_cast_in_comparison_for_binary,
 };
+use crate::{
+    analyzer::type_coercion::TypeCoercionRewriter,
+    simplify_expressions::udf_preimage::rewrite_with_preimage,
+};
 use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map;
 use datafusion_expr_common::casts::try_cast_literal_to_type;
 use indexmap::IndexSet;
@@ -1969,12 +1972,91 @@ impl TreeNodeRewriter for Simplifier<'_> {
                 }))
             }
 
+            // =======================================
+            // preimage_in_comparison
+            // =======================================
+            //
+            // For case:
+            // date_part('YEAR', expr) op literal
+            //
+            // For details see datafusion_expr::ScalarUDFImpl::preimage
+            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+                use datafusion_expr::Operator::*;
+                let is_preimage_op = matches!(
+                    op,
+                    Eq | NotEq
+                        | Lt
+                        | LtEq
+                        | Gt
+                        | GtEq
+                        | IsDistinctFrom
+                        | IsNotDistinctFrom
+                );
+                if !is_preimage_op || is_null(&right) {
+                    return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
+                        left,
+                        op,
+                        right,
+                    })));
+                }
+
+                if let PreimageResult::Range { interval, expr } =
+                    get_preimage(left.as_ref(), right.as_ref(), info)?
+                {
+                    rewrite_with_preimage(*interval, op, expr)?
+                } else if let Some(swapped) = op.swap() {
+                    if let PreimageResult::Range { interval, expr } =
+                        get_preimage(right.as_ref(), left.as_ref(), info)?
+                    {
+                        rewrite_with_preimage(*interval, swapped, expr)?
+                    } else {
+                        Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right }))
+                    }
+                } else {
+                    Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right }))
+                }
+            }
+
             // no additional rewrites possible
             expr => Transformed::no(expr),
         })
     }
 }
 
+/// Asks the scalar function on the left for its preimage of the right-hand value.
+///
+/// Returns [`PreimageResult::None`] unless `left_expr` is a call to an immutable
+/// scalar function and `right_expr` is a literal (optionally wrapped in a cast).
+fn get_preimage(
+    left_expr: &Expr,
+    right_expr: &Expr,
+    info: &SimplifyContext,
+) -> Result<PreimageResult> {
+    let Expr::ScalarFunction(ScalarFunction { func, args }) = left_expr else {
+        return Ok(PreimageResult::None);
+    };
+    if !is_literal_or_literal_cast(right_expr) {
+        return Ok(PreimageResult::None);
+    }
+    if func.signature().volatility != Volatility::Immutable {
+        return Ok(PreimageResult::None);
+    }
+    func.preimage(args, right_expr, info)
+}
+
+/// Returns true if `expr` is a literal, or a `CAST`/`TRY_CAST` applied directly to
+/// a literal.
+fn is_literal_or_literal_cast(expr: &Expr) -> bool {
+    match expr {
+        Expr::Literal(_, _) => true,
+        Expr::Cast(Cast { expr, .. }) => matches!(expr.as_ref(), Expr::Literal(_, _)),
+        Expr::TryCast(TryCast { expr, .. }) => {
+            matches!(expr.as_ref(), Expr::Literal(_, _))
+        }
+        _ => false,
+    }
+}
+
 fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option<String>)> {
     match expr {
         Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)),
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs
index 3ab76119cca84..b85b000821ad8 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -24,6 +24,7 @@ mod regex;
 pub mod simplify_exprs;
 pub mod simplify_literal;
 mod simplify_predicates;
+mod udf_preimage;
 mod unwrap_cast;
 mod utils;
 
diff --git a/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs
new file mode 100644
index 0000000000000..e0837196ca990
--- /dev/null
+++ b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs
@@ -0,0 +1,364 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::{Result, internal_err, tree_node::Transformed};
+use datafusion_expr::{Expr, Operator, and, lit, or};
+use datafusion_expr_common::interval_arithmetic::Interval;
+
+/// Rewrites a binary expression using its "preimage"
+///
+/// Specifically it rewrites expressions of the form `<expr> OP x` (e.g. `<expr> =
+/// x`) where `<expr>` is known to have a pre-image (aka the entire single
+/// range for which it is valid) and `x` is not `NULL`
+///
+/// For details see [`datafusion_expr::ScalarUDFImpl::preimage`]
+///
+pub(super) fn rewrite_with_preimage(
+    preimage_interval: Interval,
+    op: Operator,
+    expr: Expr,
+) -> Result<Transformed<Expr>> {
+    let (lower, upper) = preimage_interval.into_bounds();
+    let (lower, upper) = (lit(lower), lit(upper));
+
+    let rewritten_expr = match op {
+        // <expr> < x ==> <expr> < lower
+        Operator::Lt => expr.lt(lower),
+        // <expr> >= x ==> <expr> >= lower
+        Operator::GtEq => expr.gt_eq(lower),
+        // <expr> > x ==> <expr> >= upper
+        Operator::Gt => expr.gt_eq(upper),
+        // <expr> <= x ==> <expr> < upper
+        Operator::LtEq => expr.lt(upper),
+        // <expr> = x ==> (<expr> >= lower) and (<expr> < upper)
+        Operator::Eq => and(expr.clone().gt_eq(lower), expr.lt(upper)),
+        // <expr> != x ==> (<expr> < lower) or (<expr> >= upper)
+        Operator::NotEq => or(expr.clone().lt(lower), expr.gt_eq(upper)),
+        // <expr> is not distinct from x ==> (<expr> is NULL and x is NULL) or ((<expr> >= lower) and (<expr> < upper))
+        // but since x is always not NULL => (<expr> is not NULL) and (<expr> >= lower) and (<expr> < upper)
+        Operator::IsNotDistinctFrom => expr
+            .clone()
+            .is_not_null()
+            .and(expr.clone().gt_eq(lower))
+            .and(expr.lt(upper)),
+        // <expr> is distinct from x ==> (<expr> < lower) or (<expr> >= upper) or (<expr> is NULL and x is not NULL) or (<expr> is not NULL and x is NULL)
+        // but given that x is always not NULL => (<expr> < lower) or (<expr> >= upper) or (<expr> is NULL)
+        Operator::IsDistinctFrom => expr
+            .clone()
+            .lt(lower)
+            .or(expr.clone().gt_eq(upper))
+            .or(expr.is_null()),
+        _ => return internal_err!("Expect comparison operators"),
+    };
+    Ok(Transformed::yes(rewritten_expr))
+}
+
+#[cfg(test)]
+mod test {
+    use std::any::Any;
+    use std::sync::Arc;
+
+    use arrow::datatypes::{DataType, Field};
+    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue};
+    use datafusion_expr::{
+        ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
+        Signature, Volatility, and, binary_expr, col, lit, preimage::PreimageResult,
+        simplify::SimplifyContext,
+    };
+
+    use super::Interval;
+    use crate::simplify_expressions::ExprSimplifier;
+
+    fn is_distinct_from(left: Expr, right: Expr) -> Expr {
+        binary_expr(left, Operator::IsDistinctFrom, right)
+    }
+
+    fn is_not_distinct_from(left: Expr, right: Expr) -> Expr {
+        binary_expr(left, Operator::IsNotDistinctFrom, right)
+    }
+
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct PreimageUdf {
+        /// Defaults to an exact signature with one Int32 argument and Immutable volatility
+        signature: Signature,
+        /// If true, returns a preimage; otherwise, returns None
+        enabled: bool,
+    }
+
+    impl PreimageUdf {
+        fn new() -> Self {
+            Self {
+                signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
+                enabled: true,
+            }
+        }
+
+        /// Set the enabled flag
+        fn with_enabled(mut self, enabled: bool) -> Self {
+            self.enabled = enabled;
+            self
+        }
+
+        /// Set the volatility
+        fn with_volatility(mut self, volatility: Volatility) -> Self {
+            self.signature.volatility = volatility;
+            self
+        }
+    }
+
+    impl ScalarUDFImpl for PreimageUdf {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "preimage_func"
+        }
+
+        fn signature(&self) -> &Signature {
+            &self.signature
+        }
+
+        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+            Ok(DataType::Int32)
+        }
+
+        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(500))))
+        }
+
+        fn preimage(
+            &self,
+            args: &[Expr],
+            lit_expr: &Expr,
+            _info: &SimplifyContext,
+        ) -> Result<PreimageResult> {
+            if !self.enabled {
+                return Ok(PreimageResult::None);
+            }
+            if args.len() != 1 {
+                return Ok(PreimageResult::None);
+            }
+
+            let expr = args.first().cloned().expect("Should be column expression");
+            match lit_expr {
+                Expr::Literal(ScalarValue::Int32(Some(500)), _) => {
+                    Ok(PreimageResult::Range {
+                        expr,
+                        interval: Box::new(Interval::try_new(
+                            ScalarValue::Int32(Some(100)),
+                            ScalarValue::Int32(Some(200)),
+                        )?),
+                    })
+                }
+                _ => Ok(PreimageResult::None),
+            }
+        }
+    }
+
+    fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
+        let simplify_context = SimplifyContext::default().with_schema(Arc::clone(schema));
+        ExprSimplifier::new(simplify_context)
+            .simplify(expr)
+            .unwrap()
+    }
+
+    fn preimage_udf_expr() -> Expr {
+        ScalarUDF::new_from_impl(PreimageUdf::new()).call(vec![col("x")])
+    }
+
+    fn non_immutable_udf_expr() -> Expr {
+        ScalarUDF::new_from_impl(PreimageUdf::new().with_volatility(Volatility::Volatile))
+            .call(vec![col("x")])
+    }
+
+    fn no_preimage_udf_expr() -> Expr {
+        ScalarUDF::new_from_impl(PreimageUdf::new().with_enabled(false))
+            .call(vec![col("x")])
+    }
+
+    fn test_schema() -> DFSchemaRef {
+        Arc::new(
+            DFSchema::from_unqualified_fields(
+                vec![Field::new("x", DataType::Int32, true)].into(),
+                Default::default(),
+            )
+            .unwrap(),
+        )
+    }
+
+    fn test_schema_xy() -> DFSchemaRef {
+        Arc::new(
+            DFSchema::from_unqualified_fields(
+                vec![
+                    Field::new("x", DataType::Int32, false),
+                    Field::new("y", DataType::Int32, false),
+                ]
+                .into(),
+                Default::default(),
+            )
+            .unwrap(),
+        )
+    }
+
+    #[test]
+    fn test_preimage_eq_rewrite() {
+        // Equality rewrite when preimage and column expression are available.
+        let schema = test_schema();
+        let expr = preimage_udf_expr().eq(lit(500));
+        let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200)));
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_noteq_rewrite() {
+        // Inequality rewrite expands to disjoint ranges.
+        let schema = test_schema();
+        let expr = preimage_udf_expr().not_eq(lit(500));
+        let expected = col("x").lt(lit(100)).or(col("x").gt_eq(lit(200)));
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_eq_rewrite_swapped() {
+        // Equality rewrite works when the literal appears on the left.
+        let schema = test_schema();
+        let expr = lit(500).eq(preimage_udf_expr());
+        let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200)));
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_lt_rewrite() {
+        // Less-than comparison rewrites to the lower bound.
+        let schema = test_schema();
+        let expr = preimage_udf_expr().lt(lit(500));
+        let expected = col("x").lt(lit(100));
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_lteq_rewrite() {
+        // Less-than-or-equal comparison rewrites to the upper bound.
+        let schema = test_schema();
+        let expr = preimage_udf_expr().lt_eq(lit(500));
+        let expected = col("x").lt(lit(200));
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_gt_rewrite() {
+        // Greater-than comparison rewrites to the upper bound (inclusive).
+        let schema = test_schema();
+        let expr = preimage_udf_expr().gt(lit(500));
+        let expected = col("x").gt_eq(lit(200));
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_gteq_rewrite() {
+        // Greater-than-or-equal comparison rewrites to the lower bound.
+        let schema = test_schema();
+        let expr = preimage_udf_expr().gt_eq(lit(500));
+        let expected = col("x").gt_eq(lit(100));
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_is_not_distinct_from_rewrite() {
+        // IS NOT DISTINCT FROM rewrites to equality plus expression not-null check
+        // for non-null literal RHS.
+        let schema = test_schema();
+        let expr = is_not_distinct_from(preimage_udf_expr(), lit(500));
+        let expected = col("x")
+            .is_not_null()
+            .and(col("x").gt_eq(lit(100)))
+            .and(col("x").lt(lit(200)));
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_is_distinct_from_rewrite() {
+        // IS DISTINCT FROM adds an explicit NULL branch for the column.
+        let schema = test_schema();
+        let expr = is_distinct_from(preimage_udf_expr(), lit(500));
+        let expected = col("x")
+            .lt(lit(100))
+            .or(col("x").gt_eq(lit(200)))
+            .or(col("x").is_null());
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_non_literal_rhs_no_rewrite() {
+        // Non-literal RHS should not be rewritten.
+        let schema = test_schema_xy();
+        let expr = preimage_udf_expr().eq(col("y"));
+        let expected = expr.clone();
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_null_literal_no_rewrite_distinct_ops() {
+        // NULL literal RHS should not be rewritten for DISTINCTness operators:
+        // - `expr IS DISTINCT FROM NULL` <=> `NOT (expr IS NULL)`
+        // - `expr IS NOT DISTINCT FROM NULL` <=> `expr IS NULL`
+        //
+        // For normal comparisons (=, !=, <, <=, >, >=), `expr OP NULL` evaluates to NULL
+        // under SQL tri-state logic, and DataFusion's simplifier constant-folds it.
+        // https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html#boolean-tri-state-logic
+
+        let schema = test_schema();
+
+        let expr = is_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None)));
+        assert_eq!(optimize_test(expr.clone(), &schema), expr);
+
+        let expr =
+            is_not_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None)));
+        assert_eq!(optimize_test(expr.clone(), &schema), expr);
+    }
+
+    #[test]
+    fn test_preimage_non_immutable_no_rewrite() {
+        // Non-immutable UDFs should not participate in preimage rewrites.
+        let schema = test_schema();
+        let expr = non_immutable_udf_expr().eq(lit(500));
+        let expected = expr.clone();
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_no_preimage_no_rewrite() {
+        // If the UDF provides no preimage, the expression should remain unchanged.
+        let schema = test_schema();
+        let expr = no_preimage_udf_expr().eq(lit(500));
+        let expected = expr.clone();
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+}