diff --git a/Cargo.lock b/Cargo.lock index 0e9337b50e6f2..8d9a4634231e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2567,6 +2567,7 @@ dependencies = [ "chrono", "crc32fast", "criterion", + "datafusion", "datafusion-catalog", "datafusion-common", "datafusion-execution", diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index ad2620a532f24..dcb586ee809c2 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -29,6 +29,10 @@ edition = { workspace = true } [package.metadata.docs.rs] all-features = true +[features] +default = [] +core = ["datafusion"] + # Note: add additional linter rules in lib.rs. # Rust does not support workspace + new linter rules in subcrates yet # https://github.com/rust-lang/cargo/issues/13157 @@ -43,6 +47,8 @@ arrow = { workspace = true } bigdecimal = { workspace = true } chrono = { workspace = true } crc32fast = "1.4" +# Optional dependency for SessionStateBuilderSpark extension trait +datafusion = { workspace = true, optional = true, default-features = false } datafusion-catalog = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } @@ -59,6 +65,8 @@ url = { workspace = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true } +# for SessionStateBuilderSpark tests +datafusion = { workspace = true, default-features = false } [[bench]] harness = false diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs index f67367734cf93..ff32a9bb1fc69 100644 --- a/datafusion/spark/src/lib.rs +++ b/datafusion/spark/src/lib.rs @@ -93,10 +93,31 @@ //! ``` //! //![`Expr`]: datafusion_expr::Expr +//! +//! # Example: enabling Apache Spark features with SessionStateBuilder +//! +//! The recommended way to enable Apache Spark compatibility is to use the +//! `SessionStateBuilderSpark` extension trait. This registers all +//! Apache Spark functions (scalar, aggregate, window, and table) as well as the Apache Spark +//! expression planner. +//! +//! Enable the `core` feature in your `Cargo.toml`: +//! ```toml +//! datafusion-spark = { version = "X", features = ["core"] } +//! ``` +//! +//! Then use the extension trait - see [`SessionStateBuilderSpark::with_spark_features`] +//! for an example. pub mod function; pub mod planner; +#[cfg(feature = "core")] +mod session_state; + +#[cfg(feature = "core")] +pub use session_state::SessionStateBuilderSpark; + use datafusion_catalog::TableFunction; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; diff --git a/datafusion/spark/src/session_state.rs b/datafusion/spark/src/session_state.rs new file mode 100644 index 0000000000000..e39de3a5888ea --- /dev/null +++ b/datafusion/spark/src/session_state.rs @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::execution::SessionStateBuilder; + +use crate::planner::SparkFunctionPlanner; +use crate::{ + all_default_aggregate_functions, all_default_scalar_functions, + all_default_table_functions, all_default_window_functions, +}; + +/// Extension trait for adding Apache Spark features to [`SessionStateBuilder`]. +/// +/// This trait provides a convenient way to register all Apache Spark-compatible +/// functions and planners with a DataFusion session. +/// +/// # Example +/// +/// ```rust +/// use datafusion::execution::SessionStateBuilder; +/// use datafusion_spark::SessionStateBuilderSpark; +/// +/// // Create a SessionState with Apache Spark features enabled +/// // note: the order matters here, `with_spark_features` should be +/// // called after `with_default_features` to overwrite any existing functions +/// let state = SessionStateBuilder::new() +/// .with_default_features() +/// .with_spark_features() +/// .build(); +/// ``` +pub trait SessionStateBuilderSpark { + /// Adds all expr_planners, scalar, aggregate, window and table functions + /// compatible with Apache Spark. + /// + /// Note: This overwrites any previously registered items with the same name. + fn with_spark_features(self) -> Self; +} + +impl SessionStateBuilderSpark for SessionStateBuilder { + fn with_spark_features(mut self) -> Self { + self.expr_planners() + .get_or_insert_with(Vec::new) + // planners are evaluated in order of insertion. Push Apache Spark function planner to the front + // to take precedence over others + .insert(0, Arc::new(SparkFunctionPlanner)); + + self.scalar_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_scalar_functions()); + + self.aggregate_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_aggregate_functions()); + + self.window_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_window_functions()); + + self.table_functions() + .get_or_insert_with(HashMap::new) + .extend( + all_default_table_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)), + ); + + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_session_state_with_spark_features() { + let state = SessionStateBuilder::new().with_spark_features().build(); + + assert!( + state.scalar_functions().contains_key("sha2"), + "Apache Spark scalar function 'sha2' should be registered" + ); + + assert!( + state.aggregate_functions().contains_key("try_sum"), + "Apache Spark aggregate function 'try_sum' should be registered" + ); + + assert!( + !state.expr_planners().is_empty(), + "Apache Spark expr planners should be registered" + ); + } +} diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 13ae6e6a57e01..ca5c126b91d1a 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -47,7 +47,7 @@ bytes = { workspace = true, optional = true } chrono = { workspace = true, optional = true } clap = { version = "4.5.53", features = ["derive", "env"] } datafusion = { workspace = true, default-features = true, features = ["avro"] } -datafusion-spark = { workspace = true, default-features = true } +datafusion-spark = { workspace = true, features = ["core"] } datafusion-substrait = { workspace = true, default-features = true } futures = { workspace = true } half = { workspace = true, default-features = true } diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 19ec3e7613942..8bd0cabcb05b0 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -46,6 +46,7 @@ use datafusion::{ datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; +use datafusion_spark::SessionStateBuilderSpark; use crate::is_spark_path; use async_trait::async_trait; @@ -84,21 +85,14 @@ impl TestContext { let mut state_builder = SessionStateBuilder::new() .with_config(config) - .with_runtime_env(runtime); + .with_runtime_env(runtime) + .with_default_features(); if is_spark_path(relative_path) { - state_builder = state_builder.with_expr_planners(vec![Arc::new( - datafusion_spark::planner::SparkFunctionPlanner, - )]); + state_builder = state_builder.with_spark_features(); } - let mut state = state_builder.with_default_features().build(); - - if is_spark_path(relative_path) { - info!("Registering Spark functions"); - datafusion_spark::register_all(&mut state) - .expect("Can not register Spark functions"); - } + let state = state_builder.build(); let mut test_ctx = TestContext::new(SessionContext::new_with_state(state));