Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,25 @@ impl ScalarUDF {
self.inner.is_nullable(args, schema)
}

/// Returns the preimage of this function for the given arguments and
/// literal expression, if the underlying implementation provides one.
///
/// This simply forwards to the wrapped implementation.
///
/// See [`ScalarUDFImpl::preimage`] for more details.
pub fn preimage(
&self,
args: &[Expr],
lit_expr: &Expr,
info: &SimplifyContext,
) -> Result<Option<Interval>> {
self.inner.preimage(args, lit_expr, info)
}

/// Returns the inner column expression selected from the function `args`,
/// if the underlying implementation reports one.
///
/// This simply forwards to the wrapped implementation.
///
/// See [`ScalarUDFImpl::column_expr`]
pub fn column_expr(&self, args: &[Expr]) -> Option<Expr> {
self.inner.column_expr(args)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_with_args`] for details.
Expand Down Expand Up @@ -696,6 +715,40 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
Ok(ExprSimplifyResult::Original(args))
}

/// Returns the [preimage] for this function and the specified scalar value, if any.
///
/// A preimage is a single contiguous [`Interval`] of values for which the
/// function will always return the value of `lit_expr`.
///
/// Implementations should return intervals with an inclusive lower bound and
/// exclusive upper bound.
///
/// This rewrite is described in the [ClickHouse Paper] and is particularly
/// useful for simplifying expressions such as `date_part` or equivalent
/// functions. The idea is that if you have an expression like
/// `date_part(YEAR, k) = 2024` and you can find a [preimage] for
/// `date_part(YEAR, k)` (here, the range of dates covering the entire year of
/// 2024), you can rewrite the expression to
/// `k >= '2024-01-01' AND k < '2025-01-01'`, which is often more optimizable.
///
/// Implementations must also provide [`ScalarUDFImpl::column_expr`] so the
/// optimizer can identify which argument maps to the preimage interval.
///
/// The default implementation returns `Ok(None)`, i.e. no preimage.
///
/// [ClickHouse Paper]: https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf
/// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image
fn preimage(
&self,
_args: &[Expr],
_lit_expr: &Expr,
_info: &SimplifyContext,
) -> Result<Option<Interval>> {
Ok(None)
}

/// Returns the inner column expression from this function's arguments,
/// i.e. the argument that a [`ScalarUDFImpl::preimage`] interval applies to.
///
/// The default implementation returns `None`.
fn column_expr(&self, _args: &[Expr]) -> Option<Expr> {
None
}

/// Returns true if some of this `exprs` subexpressions may not be evaluated
/// and thus any side effects (like divide by zero) may not be encountered.
///
Expand Down Expand Up @@ -926,6 +979,19 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
self.inner.simplify(args, info)
}

// Aliasing only changes the name: forward the preimage request unchanged
// to the wrapped implementation.
fn preimage(
&self,
args: &[Expr],
lit_expr: &Expr,
info: &SimplifyContext,
) -> Result<Option<Interval>> {
self.inner.preimage(args, lit_expr, info)
}

// Forward to the wrapped implementation so the alias reports the same
// column expression as the function it wraps.
fn column_expr(&self, args: &[Expr]) -> Option<Expr> {
self.inner.column_expr(args)
}

fn conditional_arguments<'a>(
&self,
args: &'a [Expr],
Expand Down
99 changes: 97 additions & 2 deletions datafusion/functions/src/datetime/date_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
use arrow::datatypes::{
DataType, Field, FieldRef, IntervalUnit as ArrowIntervalUnit, TimeUnit,
};
use arrow::temporal_conversions::{
MICROSECONDS_IN_DAY, MILLISECONDS_IN_DAY, NANOSECONDS_IN_DAY, SECONDS_IN_DAY,
};
use chrono::{Datelike, NaiveDate};
use datafusion_common::types::{NativeType, logical_date};

use datafusion_common::{
Expand All @@ -44,9 +48,10 @@ use datafusion_common::{
types::logical_string,
utils::take_function_args,
};
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{
ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature,
TypeSignature, Volatility, interval_arithmetic,
};
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
use datafusion_macros::user_doc;
Expand Down Expand Up @@ -237,6 +242,67 @@ impl ScalarUDFImpl for DatePartFunc {
})
}

/// Computes the preimage of `date_part(YEAR, col)` for an `Int32` literal.
///
/// Only the YEAR part is supported, because it is the only part whose
/// preimage is a single contiguous range:
///
/// `date_part(YEAR, col) = 2024` => `col >= '2024-01-01' AND col < '2025-01-01'`
///
/// Smaller parts map to many disjoint ranges (for example
/// `date_part(MONTH, col) = 1` matches every January of every year), so no
/// single interval can describe them.
///
/// Returns `Ok(None)` (meaning "no rewrite") when:
/// * the part argument is not a literal naming YEAR,
/// * `lit_expr` is not a non-null `Int32` literal,
/// * `col_expr` is not `Date32`, `Date64`, or a timestamp without time zone,
/// * the requested year, or the following one, cannot be represented.
///
/// # Errors
///
/// Propagates errors from argument-count validation, data type resolution,
/// and interval construction.
fn preimage(
    &self,
    args: &[Expr],
    lit_expr: &Expr,
    info: &SimplifyContext,
) -> Result<Option<interval_arithmetic::Interval>> {
    let [part, col_expr] = take_function_args(self.name(), args)?;

    // Resolve the requested date part from the (possibly quoted) literal.
    let interval_unit = part
        .as_literal()
        .and_then(|sv| sv.try_as_str().flatten())
        .map(part_normalization)
        .and_then(|s| IntervalUnit::from_str(s).ok());

    // Only extracting the year has a contiguous preimage.
    if !matches!(interval_unit, Some(IntervalUnit::Year)) {
        return Ok(None);
    }

    // The compared value must be a non-null Int32 literal,
    // e.g. the `2024` in `date_part(YEAR, col) = 2024`.
    let Some(ScalarValue::Int32(Some(year))) = lit_expr.as_literal() else {
        return Ok(None);
    };
    let year = *year;

    // Resolve the column type once and reuse it for both bounds.
    let target_type = info.get_data_type(col_expr)?;
    match &target_type {
        // Time-zone-aware timestamps are skipped: the bounds below are
        // computed as UTC midnights, which would not line up with year
        // boundaries evaluated in the column's own time zone.
        Date32 | Date64 | Timestamp(_, None) => {}
        _ => return Ok(None),
    }

    // The literal is user supplied, so an unrepresentable year must yield
    // "no preimage" rather than a panic.
    let Some(start_time) = NaiveDate::from_ymd_opt(year, 1, 1) else {
        return Ok(None);
    };
    let Some(end_time) = year
        .checked_add(1)
        .and_then(|next_year| start_time.with_year(next_year))
    else {
        return Ok(None);
    };

    // Convert both bounds to the column's type; give up if either fails.
    let (Some(lower), Some(upper)) = (
        date_to_scalar(start_time, &target_type),
        date_to_scalar(end_time, &target_type),
    ) else {
        return Ok(None);
    };

    Ok(Some(interval_arithmetic::Interval::try_new(lower, upper)?))
}

/// Returns the expression the preimage interval constrains: the second
/// argument of `date_part(part, expr)`.
///
/// Returns `None` instead of panicking when fewer than two arguments are
/// supplied.
fn column_expr(&self, args: &[Expr]) -> Option<Expr> {
    args.get(1).cloned()
}

/// Returns the aliases stored on this function.
fn aliases(&self) -> &[String] {
&self.aliases
}
Expand All @@ -251,6 +317,35 @@ fn is_epoch(part: &str) -> bool {
matches!(part.to_lowercase().as_str(), "epoch")
}

/// Converts `date` into a [`ScalarValue`] of `target_type`, expressed as the
/// number of days since the Unix epoch scaled to the target unit.
///
/// No time zone adjustment is applied; for timestamps the zone of
/// `target_type` is only copied onto the resulting scalar.
///
/// Returns `None` when `target_type` is not `Date32`, `Date64`, or
/// `Timestamp`, or when the value does not fit the target representation
/// (nanosecond timestamps, for instance, only span roughly the years 1677
/// through 2262).
fn date_to_scalar(date: NaiveDate, target_type: &DataType) -> Option<ScalarValue> {
    let days = date
        .signed_duration_since(NaiveDate::from_epoch_days(0)?)
        .num_days();

    // Scale the day count with overflow checking so that far-away dates
    // produce `None` instead of wrapping or panicking.
    let scaled = |units_per_day: i64| days.checked_mul(units_per_day);

    Some(match target_type {
        Date32 => ScalarValue::Date32(Some(i32::try_from(days).ok()?)),
        Date64 => ScalarValue::Date64(Some(scaled(MILLISECONDS_IN_DAY)?)),
        Timestamp(unit, tz) => match unit {
            Second => {
                ScalarValue::TimestampSecond(Some(scaled(SECONDS_IN_DAY)?), tz.clone())
            }
            Millisecond => ScalarValue::TimestampMillisecond(
                Some(scaled(MILLISECONDS_IN_DAY)?),
                tz.clone(),
            ),
            Microsecond => ScalarValue::TimestampMicrosecond(
                Some(scaled(MICROSECONDS_IN_DAY)?),
                tz.clone(),
            ),
            Nanosecond => ScalarValue::TimestampNanosecond(
                Some(scaled(NANOSECONDS_IN_DAY)?),
                tz.clone(),
            ),
        },
        _ => return None,
    })
}

// Try to remove quote if exist, if the quote is invalid, return original string and let the downstream function handle the error
fn part_normalization(part: &str) -> &str {
part.strip_prefix(|c| c == '\'' || c == '\"')
Expand Down
96 changes: 94 additions & 2 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use datafusion_common::{
};
use datafusion_expr::{
BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, and,
binary::BinaryTypeCoercer, lit, or,
binary::BinaryTypeCoercer, interval_arithmetic::Interval, lit, or,
};
use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult};
use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
Expand All @@ -51,14 +51,17 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP

use super::inlist_simplifier::ShortenInListSimplifier;
use super::utils::*;
use crate::analyzer::type_coercion::TypeCoercionRewriter;
use crate::simplify_expressions::SimplifyContext;
use crate::simplify_expressions::regex::simplify_regex_expr;
use crate::simplify_expressions::unwrap_cast::{
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
unwrap_cast_in_comparison_for_binary,
};
use crate::{
analyzer::type_coercion::TypeCoercionRewriter,
simplify_expressions::udf_preimage::rewrite_with_preimage,
};
use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map;
use datafusion_expr_common::casts::try_cast_literal_to_type;
use indexmap::IndexSet;
Expand Down Expand Up @@ -1969,12 +1972,101 @@ impl TreeNodeRewriter for Simplifier<'_> {
}))
}

// =======================================
// preimage_in_comparison
// =======================================
//
// For case:
// date_part('YEAR', expr) op literal
//
// Background:
// Datasources such as Parquet can prune partitions using simple predicates,
// but they cannot do so for complex expressions.
// For a complex predicate like `date_part('YEAR', c1) < 2000`, pruning is not possible.
// After rewriting it to `c1 < 2000-01-01`, pruning becomes feasible.
// Rewrites use inclusive lower and exclusive upper bounds when
// translating an equality into a range.
// NOTE: we only consider immutable UDFs with literal RHS values and
// UDFs that provide both `preimage` and `column_expr`.
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
use datafusion_expr::Operator::*;
let is_preimage_op = matches!(
op,
Eq | NotEq
| Lt
| LtEq
| Gt
| GtEq
| IsDistinctFrom
| IsNotDistinctFrom
);
if !is_preimage_op {
return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
left,
op,
right,
})));
}

if let (Some(interval), Some(col_expr)) =
get_preimage(left.as_ref(), right.as_ref(), info)?
{
rewrite_with_preimage(info, interval, op, Box::new(col_expr))?
} else if let Some(swapped) = op.swap() {
if let (Some(interval), Some(col_expr)) =
get_preimage(right.as_ref(), left.as_ref(), info)?
{
rewrite_with_preimage(
info,
interval,
swapped,
Box::new(col_expr),
)?
} else {
Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right }))
}
} else {
Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right }))
}
}

// no additional rewrites possible
expr => Transformed::no(expr),
})
}
}

/// Attempts to obtain a preimage for `left_expr` compared with `right_expr`.
///
/// Yields `(None, None)` unless `left_expr` is a scalar function call with
/// immutable volatility and `right_expr` is a literal or a (try-)cast of a
/// literal. Otherwise returns the function's preimage interval together with
/// the column expression it reports; either element may still be `None`.
fn get_preimage(
    left_expr: &Expr,
    right_expr: &Expr,
    info: &SimplifyContext,
) -> Result<(Option<Interval>, Option<Expr>)> {
    let udf_call = match left_expr {
        Expr::ScalarFunction(call) => call,
        _ => return Ok((None, None)),
    };
    let ScalarFunction { func, args } = udf_call;

    // Both conditions must hold before the function is asked for a preimage.
    let rewritable = is_literal_or_literal_cast(right_expr)
        && func.signature().volatility == Volatility::Immutable;
    if !rewritable {
        return Ok((None, None));
    }

    // The interval is computed first so that an error short-circuits before
    // the column expression is requested.
    let interval = func.preimage(args, right_expr, info)?;
    let column = func.column_expr(args);
    Ok((interval, column))
}

/// Reports whether `expr` is a literal, or a single `Cast` / `TryCast`
/// applied directly to a literal.
fn is_literal_or_literal_cast(expr: &Expr) -> bool {
    // Peel off at most one cast layer, then test what remains.
    let candidate = match expr {
        Expr::Cast(Cast { expr: inner, .. })
        | Expr::TryCast(TryCast { expr: inner, .. }) => inner.as_ref(),
        other => other,
    };
    matches!(candidate, Expr::Literal(_, _))
}

fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option<String>)> {
match expr {
Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)),
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/simplify_expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod regex;
pub mod simplify_exprs;
pub mod simplify_literal;
mod simplify_predicates;
mod udf_preimage;
mod unwrap_cast;
mod utils;

Expand Down
Loading