diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index 1a273ad033..705be019bb 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -234,6 +234,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.CreateArray.enabled` | Enable Comet acceleration for `CreateArray` | true | | `spark.comet.expression.CreateNamedStruct.enabled` | Enable Comet acceleration for `CreateNamedStruct` | true | | `spark.comet.expression.DateAdd.enabled` | Enable Comet acceleration for `DateAdd` | true | +| `spark.comet.expression.DateFromUnixDate.enabled` | Enable Comet acceleration for `DateFromUnixDate` | true | | `spark.comet.expression.DateSub.enabled` | Enable Comet acceleration for `DateSub` | true | | `spark.comet.expression.DayOfMonth.enabled` | Enable Comet acceleration for `DayOfMonth` | true | | `spark.comet.expression.DayOfWeek.enabled` | Enable Comet acceleration for `DayOfWeek` | true | @@ -266,6 +267,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.IsNull.enabled` | Enable Comet acceleration for `IsNull` | true | | `spark.comet.expression.JsonToStructs.enabled` | Enable Comet acceleration for `JsonToStructs` | true | | `spark.comet.expression.KnownFloatingPointNormalized.enabled` | Enable Comet acceleration for `KnownFloatingPointNormalized` | true | +| `spark.comet.expression.LastDay.enabled` | Enable Comet acceleration for `LastDay` | true | | `spark.comet.expression.Length.enabled` | Enable Comet acceleration for `Length` | true | | `spark.comet.expression.LessThan.enabled` | Enable Comet acceleration for `LessThan` | true | | `spark.comet.expression.LessThanOrEqual.enabled` | Enable Comet acceleration for `LessThanOrEqual` | true | diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 75c53198b8..56569bc69c 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -44,6 +44,7 @@ use datafusion_spark::function::bitwise::bit_get::SparkBitGet; use datafusion_spark::function::bitwise::bitwise_not::SparkBitwiseNot; use datafusion_spark::function::datetime::date_add::SparkDateAdd; use datafusion_spark::function::datetime::date_sub::SparkDateSub; +use datafusion_spark::function::datetime::last_day::SparkLastDay; use datafusion_spark::function::hash::sha1::SparkSha1; use datafusion_spark::function::hash::sha2::SparkSha2; use datafusion_spark::function::math::expm1::SparkExpm1; @@ -345,6 +346,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkLastDay::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default())); diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 8384a4646a..1bcf4701c2 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ 
     spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
     spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
-    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, SparkSizeFunc,
-    SparkStringSpace,
+    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateFromUnixDate,
+    SparkDateTrunc, SparkSizeFunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
 fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
     vec![
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkDateFromUnixDate::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
diff --git a/native/spark-expr/src/datetime_funcs/date_from_unix_date.rs b/native/spark-expr/src/datetime_funcs/date_from_unix_date.rs
new file mode 100644
index 0000000000..0671a9001d
--- /dev/null
+++ b/native/spark-expr/src/datetime_funcs/date_from_unix_date.rs
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, Date32Array, Int32Array};
+use arrow::datatypes::DataType;
+use datafusion::common::{utils::take_function_args, DataFusionError, Result};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark-compatible date_from_unix_date function.
+/// Converts an integer representing days since the Unix epoch (1970-01-01) to a Date32 value.
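+///
+/// For example, an input of `0` maps to `1970-01-01`, `1` to `1970-01-02`, `-1` to
+/// `1969-12-31`, and `18993` to `2022-01-01` (the same cases exercised in
+/// `CometTemporalExpressionSuite`).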
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkDateFromUnixDate {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl SparkDateFromUnixDate {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
+            aliases: vec![],
+        }
+    }
+}
+
+impl Default for SparkDateFromUnixDate {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for SparkDateFromUnixDate {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "date_from_unix_date"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Date32)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let [unix_date] = take_function_args(self.name(), args.args)?;
+        match unix_date {
+            ColumnarValue::Array(arr) => {
+                let int_array = arr.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
+                    DataFusionError::Execution(
+                        "date_from_unix_date expects Int32Array input".to_string(),
+                    )
+                })?;
+
+                // Date32 and Int32 both represent days since epoch, so we can directly
+                // reinterpret the values. The only operation needed is creating a Date32Array
+                // from the same underlying i32 values.
+                let date_array =
+                    Date32Array::new(int_array.values().clone(), int_array.nulls().cloned());
+
+                Ok(ColumnarValue::Array(Arc::new(date_array)))
+            }
+            ColumnarValue::Scalar(scalar) => {
+                // Handle scalar case by converting to single-element array and back
+                let arr = scalar.to_array()?;
+                let int_array = arr.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
+                    DataFusionError::Execution(
+                        "date_from_unix_date expects Int32 scalar input".to_string(),
+                    )
+                })?;
+
+                let date_array =
+                    Date32Array::new(int_array.values().clone(), int_array.nulls().cloned());
+
+                Ok(ColumnarValue::Array(Arc::new(date_array)))
+            }
+        }
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs
index ef8041e5fe..6022c5c2c7 100644
--- a/native/spark-expr/src/datetime_funcs/mod.rs
+++ b/native/spark-expr/src/datetime_funcs/mod.rs
@@ -15,10 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
+mod date_from_unix_date; mod date_trunc; mod extract_date_part; mod timestamp_trunc; +pub use date_from_unix_date::SparkDateFromUnixDate; pub use date_trunc::SparkDateTrunc; pub use extract_date_part::SparkHour; pub use extract_date_part::SparkMinute; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index f26fd911d8..1ac60343e4 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -69,7 +69,9 @@ pub use comet_scalar_funcs::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, register_all_comet_functions, }; -pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr}; +pub use datetime_funcs::{ + SparkDateFromUnixDate, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr, +}; pub use error::{SparkError, SparkResult}; pub use hash_funcs::*; pub use json_funcs::{FromJson, ToJson}; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e50b1d80e6..80268aedb4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -185,8 +185,10 @@ object QueryPlanSerde extends Logging with CometExprShim { private val temporalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[DateAdd] -> CometDateAdd, + classOf[DateFromUnixDate] -> CometDateFromUnixDate, classOf[DateSub] -> CometDateSub, classOf[FromUnixTime] -> CometFromUnixTime, + classOf[LastDay] -> CometLastDay, classOf[Hour] -> CometHour, classOf[Minute] -> CometMinute, classOf[Second] -> CometSecond, diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala index ef2b0f793c..2f4918342b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala +++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, WeekDay, WeekOfYear, Year} +import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateFromUnixDate, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, LastDay, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, WeekDay, WeekOfYear, Year} import org.apache.spark.sql.types.{DateType, IntegerType} import org.apache.spark.unsafe.types.UTF8String @@ -258,6 +258,10 @@ object CometDateAdd extends CometScalarFunction[DateAdd]("date_add") object CometDateSub extends CometScalarFunction[DateSub]("date_sub") +object CometLastDay extends CometScalarFunction[LastDay]("last_day") + +object CometDateFromUnixDate extends CometScalarFunction[DateFromUnixDate]("date_from_unix_date") + object CometTruncDate extends CometExpressionSerde[TruncDate] { val supportedFormats: Seq[String] = diff --git a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala index 9a23c76d82..81293c046f 100644 --- a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala @@ -21,7 +21,7 @@ package org.apache.comet import scala.util.Random -import 
org.apache.spark.sql.{CometTestBase, SaveMode} +import org.apache.spark.sql.{CometTestBase, Row, SaveMode} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataTypes, StructField, StructType} @@ -122,4 +122,66 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH StructField("fmt", DataTypes.StringType, true))) FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions()) } + + test("last_day") { + val r = new Random(42) + val schema = StructType(Seq(StructField("c0", DataTypes.DateType, true))) + val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions()) + df.createOrReplaceTempView("tbl") + + // Basic test with random dates + checkSparkAnswerAndOperator("SELECT c0, last_day(c0) FROM tbl ORDER BY c0") + + // Disable constant folding to ensure literal expressions are executed by Comet + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + // Test with literal dates - various months + checkSparkAnswerAndOperator( + "SELECT last_day(DATE('2024-01-15')), last_day(DATE('2024-02-15')), last_day(DATE('2024-12-01'))") + + // Test leap year handling (February) + checkSparkAnswerAndOperator( + "SELECT last_day(DATE('2024-02-01')), last_day(DATE('2023-02-01'))") + + // Test null handling + checkSparkAnswerAndOperator("SELECT last_day(NULL)") + } + } + + test("date_from_unix_date") { + // Create test data with unix dates in a reasonable range (1900-2100) + // -25567 = 1900-01-01, 47482 = 2100-01-01 + val r = new Random(42) + val testData = (1 to 1000).map { _ => + val unixDate = r.nextInt(73049) - 25567 // range from 1900 to 2100 + Row(if (r.nextDouble() < 0.1) null else unixDate) + } + val schema = StructType(Seq(StructField("c0", DataTypes.IntegerType, true))) + val df = spark.createDataFrame(spark.sparkContext.parallelize(testData), schema) + df.createOrReplaceTempView("tbl") + + // Basic test with random unix dates in a reasonable range + checkSparkAnswerAndOperator("SELECT c0, date_from_unix_date(c0) FROM tbl ORDER BY c0") + + // Disable constant folding to ensure literal expressions are executed by Comet + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + // Test epoch (0 = 1970-01-01) + checkSparkAnswerAndOperator("SELECT date_from_unix_date(0)") + + // Test day after epoch (1 = 1970-01-02) + checkSparkAnswerAndOperator("SELECT date_from_unix_date(1)") + + // Test day before epoch (-1 = 1969-12-31) + checkSparkAnswerAndOperator("SELECT date_from_unix_date(-1)") + + // Test a known date (18993 = 2022-01-01, calculated as days from 1970-01-01) + checkSparkAnswerAndOperator("SELECT date_from_unix_date(18993)") + + // Test null handling + checkSparkAnswerAndOperator("SELECT date_from_unix_date(NULL)") + } + } }
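
Note (not part of the patch): the array branch in date_from_unix_date.rs relies on Arrow's Date32 and Int32 sharing the same physical representation, an i32 count of days since the Unix epoch, so the kernel only rebuilds array metadata instead of converting row by row. The standalone Rust sketch below illustrates that reinterpretation under the same arrow crate API the patch already uses (PrimitiveArray::new over the existing value and null buffers); the helper name int32_days_to_date32 is illustrative only.

use arrow::array::{Array, Date32Array, Int32Array};

/// Reinterpret an Int32 day-count array as a Date32 array by reusing its buffers.
/// Illustrative sketch only; not part of the Comet patch above.
fn int32_days_to_date32(days: &Int32Array) -> Date32Array {
    // No per-row conversion: the value buffer and null mask are shared as-is.
    Date32Array::new(days.values().clone(), days.nulls().cloned())
}

fn main() {
    // 0 = 1970-01-01, 18993 = 2022-01-01, -1 = 1969-12-31 (same cases as the Scala tests)
    let days = Int32Array::from(vec![Some(0), Some(18993), Some(-1), None]);
    let dates = int32_days_to_date32(&days);

    // The Date32 values are the very same i32 day counts...
    assert_eq!(dates.value(1), 18993);
    // ...and nulls survive because the null buffer was reused.
    assert!(dates.is_null(3));
}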