Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1541,7 +1541,7 @@ async fn join_on() -> Result<()> {
)?;

assert_snapshot!(join.logical_plan(), @r"
Inner Join: Filter: a.c1 != b.c1 AND a.c2 = b.c2
Inner Join: Filter: (a.c1 != b.c1) AND (a.c2 = b.c2)
Projection: a.c1, a.c2
TableScan: a
Projection: b.c1, b.c2
Expand Down
113 changes: 92 additions & 21 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,34 +616,24 @@ impl BinaryExpr {

impl Display for BinaryExpr {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
// Put parentheses around child binary expressions so that we can see the difference
// between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
// based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
// equivalent and the parentheses are not necessary.

fn write_child(
f: &mut Formatter<'_>,
expr: &Expr,
precedence: u8,
) -> fmt::Result {
// Put parentheses around child binary expressions to avoid ambiguity.
// For example, `(1 + 2) * 3` should display with parentheses to show
// it's different from `1 + 2 * 3`. This follows DuckDB's approach of
// always adding parentheses around nested binary expressions.

fn write_child(f: &mut Formatter<'_>, expr: &Expr) -> fmt::Result {
match expr {
Expr::BinaryExpr(child) => {
let p = child.op.precedence();
if p == 0 || p < precedence {
write!(f, "({child})")?;
} else {
write!(f, "{child}")?;
}
write!(f, "({child})")?;
}
_ => write!(f, "{expr}")?,
}
Ok(())
}

let precedence = self.op.precedence();
write_child(f, self.left.as_ref(), precedence)?;
write_child(f, self.left.as_ref())?;
write!(f, " {} ", self.op)?;
write_child(f, self.right.as_ref(), precedence)
write_child(f, self.right.as_ref())
}
}

Expand Down Expand Up @@ -2828,7 +2818,11 @@ impl Display for SchemaDisplay<'_> {
}
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
write!(f, "{} {op} {}", SchemaDisplay(left), SchemaDisplay(right),)
// Note: SchemaDisplay intentionally does NOT add parentheses around
// nested binary expressions because schema names must remain stable
// for field lookups in optimizers like common_subexpr_eliminate.
// Use Display for human-readable output with parentheses.
write!(f, "{} {op} {}", SchemaDisplay(left), SchemaDisplay(right))
}
Expr::Case(Case {
expr,
Expand Down Expand Up @@ -3090,7 +3084,16 @@ impl Display for SqlDisplay<'_> {
}
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
write!(f, "{} {op} {}", SqlDisplay(left), SqlDisplay(right),)
// Add parentheses around nested binary expressions to avoid ambiguity
fn write_child(f: &mut Formatter<'_>, expr: &Expr) -> fmt::Result {
match expr {
Expr::BinaryExpr(_) => write!(f, "({})", SqlDisplay(expr)),
_ => write!(f, "{}", SqlDisplay(expr)),
}
}
write_child(f, left)?;
write!(f, " {op} ")?;
write_child(f, right)
}
Expr::Case(Case {
expr,
Expand Down Expand Up @@ -4093,4 +4096,72 @@ mod test {
}
}
}

#[test]
fn test_binary_expr_display_with_parentheses() {
// Test that nested binary expressions display with parentheses
// to avoid ambiguity. For example, (1+2)*3 should show parentheses.
let one = lit(1i64);
let two = lit(2i64);
let three = lit(3i64);

// (1+2)*3 - addition nested in multiplication
let add = Expr::BinaryExpr(BinaryExpr::new(
Box::new(one.clone()),
Operator::Plus,
Box::new(two.clone()),
));
let mul = Expr::BinaryExpr(BinaryExpr::new(
Box::new(add),
Operator::Multiply,
Box::new(three.clone()),
));

let display = format!("{mul}");
// Should contain parentheses around the addition
assert!(
display.contains("("),
"Expected parentheses in display: {display}"
);
assert_eq!(display, "(Int64(1) + Int64(2)) * Int64(3)");

// 1*(2+3) - addition nested in multiplication on the right
let add_right = Expr::BinaryExpr(BinaryExpr::new(
Box::new(two.clone()),
Operator::Multiply,
Box::new(three.clone()),
));
let mul_right = Expr::BinaryExpr(BinaryExpr::new(
Box::new(one.clone()),
Operator::Plus,
Box::new(add_right),
));

let display_right = format!("{mul_right}");
assert!(
display_right.contains("("),
"Expected parentheses in display: {display_right}"
);
assert_eq!(display_right, "Int64(1) + (Int64(2) * Int64(3))");

// (a OR b) AND c - logical operators
let a = col("a");
let b = col("b");
let c = col("c");

let or_expr =
Expr::BinaryExpr(BinaryExpr::new(Box::new(a), Operator::Or, Box::new(b)));
let and_expr = Expr::BinaryExpr(BinaryExpr::new(
Box::new(or_expr),
Operator::And,
Box::new(c),
));

let display_logical = format!("{and_expr}");
assert!(
display_logical.contains("("),
"Expected parentheses in display: {display_logical}"
);
assert_eq!(display_logical, "(a OR b) AND c");
}
}
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1721,8 +1721,8 @@ mod test {

assert_analyzed_plan_eq!(
plan,
@r"
Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
@"
Projection: (a < CAST(UInt32(2) AS Float64)) OR (a < CAST(UInt32(2) AS Float64))
EmptyRelation: rows=0
"
)
Expand Down
52 changes: 26 additions & 26 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
@ "
Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
TableScan: test
Expand Down Expand Up @@ -1251,9 +1251,9 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
@ "
Projection: test.a, test.b, test.c
Filter: __common_expr_1 - Int32(10) > __common_expr_1
Filter: (__common_expr_1 - Int32(10)) > __common_expr_1
Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
Expand Down Expand Up @@ -1375,9 +1375,9 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
@ "
Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR ((test.a - test.b) = Int32(0)) AS c3, __common_expr_2 AND ((test.a - test.b) = Int32(0)) AS c4, __common_expr_3 OR __common_expr_3 AS c5
Projection: (test.a = Int32(0)) OR (test.b = Int32(0)) AS __common_expr_1, (test.a + test.b) = Int32(0) AS __common_expr_2, (test.a * test.b) = Int32(0) AS __common_expr_3, test.a, test.b, test.c
TableScan: test
"
)
Expand Down Expand Up @@ -1429,8 +1429,8 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
@ "
Projection: __common_expr_1 OR (random() = Int32(0)) AS c1, __common_expr_1 OR (random() = Int32(0)) AS c2, (random() = Int32(0)) OR (test.b = Int32(0)) AS c3, (random() = Int32(0)) OR (test.b = Int32(0)) AS c4
Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
Expand Down Expand Up @@ -1494,9 +1494,9 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
@ "
Projection: test.a, test.b, test.c
Filter: __common_expr_1 * __common_expr_1 = Int32(30)
Filter: (__common_expr_1 * __common_expr_1) = Int32(30)
Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
Expand All @@ -1512,9 +1512,9 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
@ "
Projection: test.a, test.b, test.c
Filter: __common_expr_1 + __common_expr_1 = Int32(30)
Filter: (__common_expr_1 + __common_expr_1) = Int32(30)
Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
Expand All @@ -1530,9 +1530,9 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
@ "
Projection: test.a, test.b, test.c
Filter: __common_expr_1 + __common_expr_1 = Int32(30)
Filter: (__common_expr_1 + __common_expr_1) = Int32(30)
Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
Expand All @@ -1548,9 +1548,9 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
@ "
Projection: test.a, test.b, test.c
Filter: __common_expr_1 + __common_expr_1 = Int32(30)
Filter: (__common_expr_1 + __common_expr_1) = Int32(30)
Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
Expand All @@ -1566,9 +1566,9 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
@ "
Projection: test.a, test.b, test.c
Filter: __common_expr_1 + __common_expr_1 = Int32(30)
Filter: (__common_expr_1 + __common_expr_1) = Int32(30)
Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
Expand Down Expand Up @@ -1621,10 +1621,10 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
@ "
Projection: test.a, test.b, test.c
Filter: __common_expr_1 - __common_expr_1 = Int32(30)
Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
Filter: (__common_expr_1 - __common_expr_1) = Int32(30)
Projection: test.a + (test.b * test.c) AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
)?;
Expand All @@ -1639,10 +1639,10 @@ mod test {

assert_optimized_plan_equal!(
plan,
@ r"
@ "
Projection: test.a, test.b, test.c
Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
Filter: ((__common_expr_1 / __common_expr_1) + test.a) = Int32(30)
Projection: (test.a + (test.b / test.c)) * test.c AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
)?;
Expand All @@ -1655,9 +1655,9 @@ mod test {
let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
assert_optimized_plan_equal!(
plan,
@ r"
@ "
Projection: test.a, test.b, test.c
Filter: __common_expr_1 * __common_expr_1 = Int32(30)
Filter: (__common_expr_1 * __common_expr_1) = Int32(30)
Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
TableScan: test
"
Expand Down
Loading
Loading