diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index b39b23e30f4e8..0889afd08fee4 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -208,14 +208,16 @@ pub fn check_subquery_expr( if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( - "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions" + "Correlated scalar subquery in the GROUP BY clause must \ + also be in the aggregate expressions" ) } else { Ok(()) } } _ => plan_err!( - "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" + "Correlated scalar subquery can only be used in Projection, \ + Filter, Aggregate plan nodes" ), }?; } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 34fbe2edf8dd9..d6a40ceb51655 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -76,15 +76,16 @@ impl SqlToRel<'_, S> { } // Check the outer query schema - if let Some(outer) = planner_context.outer_query_schema() - && let Ok((qualifier, field)) = + for outer in planner_context.outer_queries_schemas() { + if let Ok((qualifier, field)) = outer.qualified_field_with_unqualified_name(normalize_ident.as_str()) - { - // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - return Ok(Expr::OuterReferenceColumn( - Arc::clone(field), - Column::from((qualifier, field)), - )); + { + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + return Ok(Expr::OuterReferenceColumn( + Arc::clone(field), + Column::from((qualifier, field)), + )); + } } // Default case @@ -172,36 +173,44 @@ impl SqlToRel<'_, S> { not_impl_err!("compound identifier: {ids:?}") } else { // Check the outer_query_schema and try to find a match - if let Some(outer) = planner_context.outer_query_schema() { - let search_result = search_dfschema(&ids, outer); - match search_result { - // Found matching field with spare identifier(s) for nested field(s) in structure - Some((field, qualifier, nested_names)) - if !nested_names.is_empty() => - { - // TODO: remove when can support nested identifiers for OuterReferenceColumn - not_impl_err!( - "Nested identifiers are not yet supported for OuterReferenceColumn {}", - Column::from((qualifier, field)) - .quoted_flat_name() - ) - } - // Found matching field with no spare identifier(s) - Some((field, qualifier, _nested_names)) => { - // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - Ok(Expr::OuterReferenceColumn( - Arc::clone(field), - Column::from((qualifier, field)), - )) - } - // Found no matching field, will return a default - None => { - let s = &ids[0..ids.len()]; - // safe unwrap as s can never be empty or exceed the bounds - let (relation, column_name) = - form_identifier(s).unwrap(); - Ok(Expr::Column(Column::new(relation, column_name))) - } + let outer_schemas = planner_context.outer_queries_schemas(); + let mut maybe_result = None; + if !outer_schemas.is_empty() { + for outer in planner_context.outer_queries_schemas() { + let search_result = search_dfschema(&ids, &outer); + let result = match search_result { + // Found matching field with spare identifier(s) for nested field(s) in structure + Some((field, qualifier, nested_names)) + if !nested_names.is_empty() => + { + // TODO: remove when can support nested identifiers for OuterReferenceColumn + not_impl_err!( + "Nested identifiers are not yet supported for OuterReferenceColumn {}", + Column::from((qualifier, field)) + .quoted_flat_name() + ) + } + // Found matching field with no spare identifier(s) + Some((field, qualifier, _nested_names)) => { + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + Ok(Expr::OuterReferenceColumn( + Arc::clone(field), + Column::from((qualifier, field)), + )) + } + // Found no matching field, will return a default + None => continue, + }; + maybe_result = Some(result); + break; + } + if let Some(result) = maybe_result { + result + } else { + let s = &ids[0..ids.len()]; + // safe unwrap as s can never be empty or exceed the bounds + let (relation, column_name) = form_identifier(s).unwrap(); + Ok(Expr::Column(Column::new(relation, column_name))) } } else { let s = &ids[0..ids.len()]; diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index 6837b2671cb1c..662c44f6f2620 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -31,11 +31,10 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(input_schema.clone().into()); let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); Ok(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(sub_plan), @@ -54,8 +53,7 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(Arc::new(input_schema.clone())); let mut spans = Spans::new(); if let SetExpr::Select(select) = &subquery.body.as_ref() { @@ -70,7 +68,7 @@ impl SqlToRel<'_, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, @@ -98,8 +96,7 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(Arc::new(input_schema.clone())); let mut spans = Spans::new(); if let SetExpr::Select(select) = subquery.body.as_ref() { for item in &select.projection { @@ -112,7 +109,7 @@ impl SqlToRel<'_, S> { } let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, @@ -172,8 +169,7 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(Arc::new(input_schema.clone())); let mut spans = Spans::new(); if let SetExpr::Select(select) = subquery.body.as_ref() { @@ -188,7 +184,7 @@ impl SqlToRel<'_, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 520a2d55ef6a2..6a3a2cc4e5b5a 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -261,8 +261,11 @@ pub struct PlannerContext { /// Map of CTE name to logical plan of the WITH clause. /// Use `Arc` to allow cheap cloning ctes: HashMap>, - /// The query schema of the outer query plan, used to resolve the columns in subquery - outer_query_schema: Option, + + /// The queries schemas of outer query relations, used to resolve the outer referenced + /// columns in subquery (recursive aware) + outer_queries_schemas_stack: Vec, + /// The joined schemas of all FROM clauses planned so far. When planning LATERAL /// FROM clauses, this should become a suffix of the `outer_query_schema`. outer_from_schema: Option, @@ -282,7 +285,7 @@ impl PlannerContext { Self { prepare_param_data_types: Arc::new(vec![]), ctes: HashMap::new(), - outer_query_schema: None, + outer_queries_schemas_stack: vec![], outer_from_schema: None, create_table_schema: None, } @@ -297,19 +300,26 @@ impl PlannerContext { self } - // Return a reference to the outer query's schema - pub fn outer_query_schema(&self) -> Option<&DFSchema> { - self.outer_query_schema.as_ref().map(|s| s.as_ref()) + /// Return the stack of outer relations' schemas, the outer most + /// relation are at the first entry + pub fn outer_queries_schemas(&self) -> Vec { + self.outer_queries_schemas_stack.to_vec() } /// Sets the outer query schema, returning the existing one, if /// any - pub fn set_outer_query_schema( - &mut self, - mut schema: Option, - ) -> Option { - std::mem::swap(&mut self.outer_query_schema, &mut schema); - schema + pub fn append_outer_query_schema(&mut self, schema: DFSchemaRef) { + self.outer_queries_schemas_stack.push(schema); + } + + /// The schema of the adjacent outer relation + pub fn latest_outer_query_schema(&mut self) -> Option { + self.outer_queries_schemas_stack.last().cloned() + } + + /// Remove the schema of the adjacent outer relation + pub fn pop_outer_query_schema(&mut self) -> Option { + self.outer_queries_schemas_stack.pop() } pub fn set_table_schema( diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index cef3726c62e40..6558763ca4e42 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -262,9 +262,10 @@ impl SqlToRel<'_, S> { } => { let tbl_func_ref = self.object_name_to_table_reference(name)?; let schema = planner_context - .outer_query_schema() + .outer_queries_schemas() + .last() .cloned() - .unwrap_or_else(DFSchema::empty); + .unwrap_or_else(|| Arc::new(DFSchema::empty())); let func_args = args .into_iter() .map(|arg| match arg { @@ -310,20 +311,24 @@ impl SqlToRel<'_, S> { let old_from_schema = planner_context .set_outer_from_schema(None) .unwrap_or_else(|| Arc::new(DFSchema::empty())); - let new_query_schema = match planner_context.outer_query_schema() { - Some(old_query_schema) => { + let outer_query_schema = planner_context.pop_outer_query_schema(); + let new_query_schema = match outer_query_schema { + Some(ref old_query_schema) => { let mut new_query_schema = old_from_schema.as_ref().clone(); - new_query_schema.merge(old_query_schema); - Some(Arc::new(new_query_schema)) + new_query_schema.merge(old_query_schema.as_ref()); + Arc::new(new_query_schema) } - None => Some(Arc::clone(&old_from_schema)), + None => Arc::clone(&old_from_schema), }; - let old_query_schema = planner_context.set_outer_query_schema(new_query_schema); + planner_context.append_outer_query_schema(new_query_schema); let plan = self.create_relation(subquery, planner_context)?; let outer_ref_columns = plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_query_schema); + planner_context.pop_outer_query_schema(); + if let Some(schema) = outer_query_schema { + planner_context.append_outer_query_schema(schema); + } planner_context.set_outer_from_schema(Some(old_from_schema)); // We can omit the subquery wrapper if there are no columns diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 1d6ccde6be13a..182bc97ad4d98 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -29,7 +29,7 @@ use crate::utils::{ use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{Column, Result, not_impl_err, plan_err}; +use datafusion_common::{Column, DFSchema, Result, not_impl_err, plan_err}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ @@ -637,12 +637,8 @@ impl SqlToRel<'_, S> { match selection { Some(predicate_expr) => { let fallback_schemas = plan.fallback_normalize_schemas(); - let outer_query_schema = planner_context.outer_query_schema().cloned(); - let outer_query_schema_vec = outer_query_schema - .as_ref() - .map(|schema| vec![schema]) - .unwrap_or_else(Vec::new); + let outer_query_schema_vec = planner_context.outer_queries_schemas(); let filter_expr = self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?; @@ -657,9 +653,19 @@ impl SqlToRel<'_, S> { let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; + let mut schema_stack: Vec> = + vec![vec![plan.schema()], fallback_schemas]; + for sc in outer_query_schema_vec.iter().rev() { + schema_stack.push(vec![sc.as_ref()]); + } + let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[plan.schema()], &fallback_schemas, &outer_query_schema_vec], + schema_stack + .iter() + .map(|sc| sc.as_slice()) + .collect::>() + .as_slice(), &[using_columns], )?; diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 9dc6b895e49ab..686fdf503f3d6 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -161,12 +161,26 @@ impl ContextProvider for MockContextProvider { ])), "orders" => Ok(Schema::new(vec![ Field::new("order_id", DataType::UInt32, false), + Field::new("o_orderkey", DataType::UInt32, false), + Field::new("o_custkey", DataType::UInt32, false), + Field::new("o_orderstatus", DataType::Utf8, false), Field::new("customer_id", DataType::UInt32, false), + Field::new("o_totalprice", DataType::Decimal32(15, 2), false), Field::new("o_item_id", DataType::Utf8, false), Field::new("qty", DataType::Int32, false), Field::new("price", DataType::Float64, false), Field::new("delivered", DataType::Boolean, false), ])), + "customer" => Ok(Schema::new(vec![ + Field::new("c_custkey", DataType::UInt32, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::UInt32, false), + Field::new("c_phone", DataType::Decimal32(15, 2), false), + Field::new("c_acctbal", DataType::Float64, false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ])), "array" => Ok(Schema::new(vec![ Field::new( "left", @@ -186,8 +200,10 @@ impl ContextProvider for MockContextProvider { ), ])), "lineitem" => Ok(Schema::new(vec![ + Field::new("l_orderkey", DataType::UInt32, false), Field::new("l_item_id", DataType::UInt32, false), Field::new("l_description", DataType::Utf8, false), + Field::new("l_extendedprice", DataType::Decimal32(15, 2), false), Field::new("price", DataType::Float64, false), ])), "aggregate_test_100" => Ok(Schema::new(vec![ diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 491873b4afe02..61569ea76ed2b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -995,15 +995,15 @@ fn select_nested_with_filters() { #[test] fn table_with_column_alias() { - let sql = "SELECT a, b, c - FROM lineitem l (a, b, c)"; + let sql = "SELECT a, b, c, d, e + FROM lineitem l (a, b, c, d, e)"; let plan = logical_plan(sql).unwrap(); assert_snapshot!( plan, @r" - Projection: l.a, l.b, l.c + Projection: l.a, l.b, l.c, l.d, l.e SubqueryAlias: l - Projection: lineitem.l_item_id AS a, lineitem.l_description AS b, lineitem.price AS c + Projection: lineitem.l_orderkey AS a, lineitem.l_item_id AS b, lineitem.l_description AS c, lineitem.l_extendedprice AS d, lineitem.price AS e TableScan: lineitem " ); @@ -1017,7 +1017,7 @@ fn table_with_column_alias_number_cols() { assert_snapshot!( err.strip_backtrace(), - @"Error during planning: Source table contains 3 columns but only 2 names given as column alias" + @"Error during planning: Source table contains 5 columns but only 2 names given as column alias" ); } @@ -1058,7 +1058,7 @@ fn natural_left_join() { plan, @r" Projection: a.l_item_id - Left Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price + Left Join: Using a.l_orderkey = b.l_orderkey, a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.l_extendedprice = b.l_extendedprice, a.price = b.price SubqueryAlias: a TableScan: lineitem SubqueryAlias: b @@ -1075,7 +1075,7 @@ fn natural_right_join() { plan, @r" Projection: a.l_item_id - Right Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price + Right Join: Using a.l_orderkey = b.l_orderkey, a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.l_extendedprice = b.l_extendedprice, a.price = b.price SubqueryAlias: a TableScan: lineitem SubqueryAlias: b @@ -4801,7 +4801,11 @@ fn test_using_join_wildcard_schema() { // Only columns from one join side should be present let expected_fields = vec![ "o1.order_id".to_string(), + "o1.o_orderkey".to_string(), + "o1.o_custkey".to_string(), + "o1.o_orderstatus".to_string(), "o1.customer_id".to_string(), + "o1.o_totalprice".to_string(), "o1.o_item_id".to_string(), "o1.qty".to_string(), "o1.price".to_string(), @@ -4855,3 +4859,138 @@ fn test_using_join_wildcard_schema() { ] ); } + +#[test] +fn test_2_nested_lateral_join_with_the_deepest_join_referencing_the_outer_most_relation() +{ + let sql = "SELECT * FROM j1 j1_outer, LATERAL ( + SELECT * FROM j1 j1_inner, LATERAL ( + SELECT * FROM j2 WHERE j1_inner.j1_id = j2_id and j1_outer.j1_id=j2_id + ) as j2 +) as j2"; + + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: j1_outer.j1_id, j1_outer.j1_string, j2.j1_id, j2.j1_string, j2.j2_id, j2.j2_string + Cross Join: + SubqueryAlias: j1_outer + TableScan: j1 + SubqueryAlias: j2 + Subquery: + Projection: j1_inner.j1_id, j1_inner.j1_string, j2.j2_id, j2.j2_string + Cross Join: + SubqueryAlias: j1_inner + TableScan: j1 + SubqueryAlias: j2 + Subquery: + Projection: j2.j2_id, j2.j2_string + Filter: outer_ref(j1_inner.j1_id) = j2.j2_id AND outer_ref(j1_outer.j1_id) = j2.j2_id + TableScan: j2 +"# + ); +} + +#[test] +fn test_correlated_recursive_scalar_subquery_with_level_3_scalar_subquery_referencing_level1_relation() + { + let sql = "select c_custkey from customer + where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice < ( + select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) + ) order by c_custkey"; + + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: customer.c_custkey ASC NULLS LAST + Projection: customer.c_custkey + Filter: customer.c_acctbal < () + Subquery: + Projection: sum(orders.o_totalprice) + Aggregate: groupBy=[[]], aggr=[[sum(orders.o_totalprice)]] + Filter: orders.o_custkey = outer_ref(customer.c_custkey) AND orders.o_totalprice < () + Subquery: + Projection: sum(lineitem.l_extendedprice) AS price + Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice)]] + Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) + TableScan: lineitem + TableScan: orders + TableScan: customer +"# + ); +} + +#[test] +fn correlated_recursive_scalar_subquery_with_level_3_exists_subquery_referencing_level1_relation() + { + let sql = "select c_custkey from customer + where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and exists ( + select * from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) + ) order by c_custkey"; + + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: customer.c_custkey ASC NULLS LAST + Projection: customer.c_custkey + Filter: customer.c_acctbal < () + Subquery: + Projection: sum(orders.o_totalprice) + Aggregate: groupBy=[[]], aggr=[[sum(orders.o_totalprice)]] + Filter: orders.o_custkey = outer_ref(customer.c_custkey) AND EXISTS () + Subquery: + Projection: lineitem.l_orderkey, lineitem.l_item_id, lineitem.l_description, lineitem.l_extendedprice, lineitem.price + Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) + TableScan: lineitem + TableScan: orders + TableScan: customer +"# + ); +} + +#[test] +fn correlated_recursive_scalar_subquery_with_level_3_in_subquery_referencing_level1_relation() + { + let sql = "select c_custkey from customer + where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice in ( + select l_extendedprice as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) + ) order by c_custkey"; + + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: customer.c_custkey ASC NULLS LAST + Projection: customer.c_custkey + Filter: customer.c_acctbal < () + Subquery: + Projection: sum(orders.o_totalprice) + Aggregate: groupBy=[[]], aggr=[[sum(orders.o_totalprice)]] + Filter: orders.o_custkey = outer_ref(customer.c_custkey) AND orders.o_totalprice IN () + Subquery: + Projection: lineitem.l_extendedprice AS price + Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) + TableScan: lineitem + TableScan: orders + TableScan: customer +"# + ); +}