From 24a9429a3edd3bec311bf418d0aaed726fab8a1f Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 14:06:35 -0400 Subject: [PATCH 01/13] migrate tests in `replace_distinct_aggregate.rs` --- .../src/replace_distinct_aggregate.rs | 37 +++++++++---------- datafusion/optimizer/src/test/mod.rs | 18 +++++++++ 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 48b2828faf452..9f7a15d493563 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -186,21 +186,26 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { mod tests { use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::test::*; use datafusion_common::Result; - use datafusion_expr::{ - col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, - }; + use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, Expr}; use datafusion_functions_aggregate::sum::sum; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq( - Arc::new(ReplaceDistinctWithAggregate::new()), - plan.clone(), - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(ReplaceDistinctWithAggregate::new()); + assert_optimized_plan_eq_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -212,8 +217,7 @@ mod tests { .distinct()? .build()?; - let expected = "Projection: test.c\n Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @ "Projection: test.c\n Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test") } #[test] @@ -225,9 +229,7 @@ mod tests { .distinct()? 
.build()?; - let expected = - "Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @ "Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test") } #[test] @@ -238,8 +240,7 @@ mod tests { .distinct()? .build()?; - let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @ "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n TableScan: test") } #[test] @@ -251,8 +252,6 @@ mod tests { .distinct()? .build()?; - let expected = - "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @ "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n TableScan: test") } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 94d07a0791b3b..cf501dce72417 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -181,6 +181,24 @@ pub fn assert_optimized_plan_eq( Ok(()) } +#[macro_export] +macro_rules! assert_optimized_plan_eq_snapshot { + ( + $rule:expr, + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + // Apply the rule once + let opt_context = crate::OptimizerContext::new().with_max_passes(1); + + let optimizer = crate::Optimizer::with_rules(vec![Arc::clone(&$rule)]); + let optimized_plan = optimizer.optimize($plan, &opt_context, |_, _| {})?; + insta::assert_snapshot!(optimized_plan, @ $expected); + + Ok(()) + }}; +} + fn generate_optimized_plan_with_rules( rules: Vec>, plan: LogicalPlan, From e7c7b4deee452add73bdd7d23c2bf75c5a469079 Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 14:38:18 -0400 Subject: [PATCH 02/13] migrate tests in `replace_distinct_aggregate.rs` --- .../src/replace_distinct_aggregate.rs | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 9f7a15d493563..c7c9d03a51ae7 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -217,7 +217,11 @@ mod tests { .distinct()? .build()?; - assert_optimized_plan_equal!(plan, @ "Projection: test.c\n Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test") + assert_optimized_plan_equal!(plan, @r" + Projection: test.c + Aggregate: groupBy=[[test.c]], aggr=[[]] + TableScan: test + ") } #[test] @@ -229,7 +233,11 @@ mod tests { .distinct()? .build()?; - assert_optimized_plan_equal!(plan, @ "Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test") + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, test.b + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + TableScan: test + ") } #[test] @@ -240,7 +248,11 @@ mod tests { .distinct()? 
.build()?; - assert_optimized_plan_equal!(plan, @ "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n TableScan: test") + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + Projection: test.a, test.b + TableScan: test + ") } #[test] @@ -252,6 +264,11 @@ mod tests { .distinct()? .build()?; - assert_optimized_plan_equal!(plan, @ "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n TableScan: test") + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + Projection: test.a, test.b + Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]] + TableScan: test + ") } } From 142890ad4bfe99ed085f85a9922904e898ad67dd Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 14:39:13 -0400 Subject: [PATCH 03/13] migrate tests in `push_down_limit.rs` --- datafusion/optimizer/src/push_down_limit.rs | 482 ++++++++++++-------- 1 file changed, 293 insertions(+), 189 deletions(-) diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 1e9ef16bde675..0ed4e05d8594f 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -276,6 +276,7 @@ mod test { use std::vec; use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; use datafusion_common::DFSchemaRef; @@ -285,8 +286,18 @@ mod test { }; use datafusion_functions_aggregate::expr_fn::max; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let rule: Arc = Arc::new(PushDownLimit::new()); + assert_optimized_plan_eq_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[derive(Debug, PartialEq, Eq, Hash)] @@ -408,12 +419,15 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -430,12 +444,15 @@ mod test { .limit(10, Some(1000))? .build()?; - let expected = "Limit: skip=10, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -453,12 +470,15 @@ mod test { .limit(20, Some(500))? .build()?; - let expected = "Limit: skip=30, fetch=500\ - \n NoopPlan\ - \n Limit: skip=0, fetch=530\ - \n TableScan: test, fetch=530"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=30, fetch=500 + NoopPlan + Limit: skip=0, fetch=530 + TableScan: test, fetch=530 + " + ) } #[test] @@ -475,14 +495,17 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -499,11 +522,14 @@ mod test { .limit(0, Some(1000))? 
.build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoLimitNoopPlan\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoLimitNoopPlan + TableScan: test + " + ) } #[test] @@ -517,11 +543,14 @@ mod test { // Should push the limit down to table provider // When it has a select - let expected = "Projection: test.a\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -536,10 +565,13 @@ mod test { // Should push down the smallest limit // Towards table scan // This rule doesn't replace multiple limits - let expected = "Limit: skip=0, fetch=10\ - \n TableScan: test, fetch=10"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + TableScan: test, fetch=10 + " + ) } #[test] @@ -552,11 +584,14 @@ mod test { .build()?; // Limit should *not* push down aggregate node - let expected = "Limit: skip=0, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + TableScan: test + " + ) } #[test] @@ -569,14 +604,17 @@ mod test { .build()?; // Limit should push down through union - let expected = "Limit: skip=0, fetch=1000\ - \n Union\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Union + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test, 
fetch=1000 + " + ) } #[test] @@ -589,11 +627,14 @@ mod test { .build()?; // Should push down limit to sort - let expected = "Limit: skip=0, fetch=10\ - \n Sort: test.a ASC NULLS LAST, fetch=10\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + Sort: test.a ASC NULLS LAST, fetch=10 + TableScan: test + " + ) } #[test] @@ -606,11 +647,14 @@ mod test { .build()?; // Should push down limit to sort - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS LAST, fetch=15\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS LAST, fetch=15 + TableScan: test + " + ) } #[test] @@ -624,12 +668,15 @@ mod test { .build()?; // Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation - let expected = "Limit: skip=0, fetch=10\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -641,10 +688,13 @@ mod test { // Should not push any limit down to table provider // When it has a select - let expected = "Limit: skip=10, fetch=None\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=None + TableScan: test + " + ) } #[test] @@ -658,11 +708,14 @@ mod test { // Should push the limit down to table provider // When it has a select - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=1000\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: 
test.a + Limit: skip=10, fetch=1000 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -675,11 +728,14 @@ mod test { .limit(10, None)? .build()?; - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=990\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=990 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -692,11 +748,14 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=1000\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=1000 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -709,10 +768,13 @@ mod test { .limit(0, Some(10))? .build()?; - let expected = "Limit: skip=10, fetch=10\ - \n TableScan: test, fetch=20"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=10 + TableScan: test, fetch=20 + " + ) } #[test] @@ -725,11 +787,14 @@ mod test { .build()?; // Limit should *not* push down aggregate node - let expected = "Limit: skip=10, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + TableScan: test + " + ) } #[test] @@ -742,14 +807,17 @@ mod test { .build()?; // Limit should push down through union - let expected = "Limit: skip=10, fetch=1000\ - \n Union\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Union + Limit: skip=0, fetch=1010 + TableScan: test, 
fetch=1010 + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -768,12 +836,15 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Inner Join: test.a = test2.a + TableScan: test + TableScan: test2 + " + ) } #[test] @@ -792,12 +863,15 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Inner Join: test.a = test2.a + TableScan: test + TableScan: test2 + " + ) } #[test] @@ -817,16 +891,19 @@ mod test { .build()?; // Limit pushdown Not supported in sub_query - let expected = "Limit: skip=10, fetch=100\ - \n Filter: EXISTS ()\ - \n Subquery:\ - \n Filter: test1.a = test1.a\ - \n Projection: test1.a\ - \n TableScan: test1\ - \n Projection: test2.a\ - \n TableScan: test2"; - - assert_optimized_plan_equal(outer_query, expected) + assert_optimized_plan_equal!( + outer_query, + @r" + Limit: skip=10, fetch=100 + Filter: EXISTS () + Subquery: + Filter: test1.a = test1.a + Projection: test1.a + TableScan: test1 + Projection: test2.a + TableScan: test2 + " + ) } #[test] @@ -846,16 +923,19 @@ mod test { .build()?; // Limit pushdown Not supported in sub_query - let expected = "Limit: skip=10, fetch=100\ - \n Filter: EXISTS ()\ - \n Subquery:\ - \n Filter: test1.a = test1.a\ - \n Projection: test1.a\ - \n TableScan: test1\ - \n Projection: test2.a\ - \n TableScan: test2"; - - assert_optimized_plan_equal(outer_query, expected) + assert_optimized_plan_equal!( + outer_query, + @r" + Limit: skip=10, fetch=100 + Filter: EXISTS () + 
Subquery: + Filter: test1.a = test1.a + Projection: test1.a + TableScan: test1 + Projection: test2.a + TableScan: test2 + " + ) } #[test] @@ -874,13 +954,16 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Left Join: test.a = test2.a\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Left Join: test.a = test2.a + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + TableScan: test2 + " + ) } #[test] @@ -899,13 +982,16 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=0, fetch=1000\ - \n Right Join: test.a = test2.a\ - \n TableScan: test\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test2, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Right Join: test.a = test2.a + TableScan: test + Limit: skip=0, fetch=1000 + TableScan: test2, fetch=1000 + " + ) } #[test] @@ -924,13 +1010,16 @@ mod test { .build()?; // Limit pushdown with offset supported in right outer join - let expected = "Limit: skip=10, fetch=1000\ - \n Right Join: test.a = test2.a\ - \n TableScan: test\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test2, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Right Join: test.a = test2.a + TableScan: test + Limit: skip=0, fetch=1010 + TableScan: test2, fetch=1010 + " + ) } #[test] @@ -943,14 +1032,17 @@ mod test { .limit(0, Some(1000))? 
.build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n Cross Join: \ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test2, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Cross Join: + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test2, fetch=1000 + " + ) } #[test] @@ -963,14 +1055,17 @@ mod test { .limit(1000, Some(1000))? .build()?; - let expected = "Limit: skip=1000, fetch=1000\ - \n Cross Join: \ - \n Limit: skip=0, fetch=2000\ - \n TableScan: test, fetch=2000\ - \n Limit: skip=0, fetch=2000\ - \n TableScan: test2, fetch=2000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=1000 + Cross Join: + Limit: skip=0, fetch=2000 + TableScan: test, fetch=2000 + Limit: skip=0, fetch=2000 + TableScan: test2, fetch=2000 + " + ) } #[test] @@ -982,10 +1077,13 @@ mod test { .limit(1000, None)? .build()?; - let expected = "Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } #[test] @@ -997,10 +1095,13 @@ mod test { .limit(1000, None)? .build()?; - let expected = "Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } #[test] @@ -1013,10 +1114,13 @@ mod test { .limit(1000, None)? 
.build()?; - let expected = "SubqueryAlias: a\ - \n Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + SubqueryAlias: a + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } } From 82786381d96ee6e101177a12f672c20b795c4b95 Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 14:44:53 -0400 Subject: [PATCH 04/13] migrate tests in `eliminate_duplicated_expr.rs` --- .../src/eliminate_duplicated_expr.rs | 39 ++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 4669500920956..6a5b29062e948 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -118,16 +118,23 @@ impl OptimizerRule for EliminateDuplicatedExpr { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - crate::test::assert_optimized_plan_eq( - Arc::new(EliminateDuplicatedExpr::new()), - plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(EliminateDuplicatedExpr::new()); + assert_optimized_plan_eq_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -137,10 +144,12 @@ mod tests { .sort_by(vec![col("a"), col("a"), col("b"), col("c")])? .limit(5, Some(10))? 
.build()?; - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST + TableScan: test + ") } #[test] @@ -156,9 +165,11 @@ mod tests { .sort(sort_exprs)? .limit(5, Some(10))? .build()?; - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST + TableScan: test + ") } } From 25ce062d24ac9d0d5d18c49b2bf706d55995a896 Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 14:48:14 -0400 Subject: [PATCH 05/13] migrate tests in `eliminate_filter.rs` --- datafusion/optimizer/src/eliminate_filter.rs | 63 ++++++++++++-------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 4ed2ac8ba1a4e..db2136e5e4e5e 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -81,17 +81,26 @@ impl OptimizerRule for EliminateFilter { mod tests { use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ - col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, - }; + use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, Expr}; use crate::eliminate_filter::EliminateFilter; use crate::test::*; use datafusion_expr::test::function_stub::sum; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) + 
macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(EliminateFilter::new()); + assert_optimized_plan_eq_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -105,8 +114,7 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } #[test] @@ -120,8 +128,7 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } #[test] @@ -139,11 +146,12 @@ mod tests { .build()?; // Left side is removed - let expected = "Union\ - \n EmptyRelation\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + EmptyRelation + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -156,9 +164,10 @@ mod tests { .filter(filter_expr)? 
.build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -176,12 +185,13 @@ mod tests { .build()?; // Filter is removed - let expected = "Union\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -202,8 +212,9 @@ mod tests { .build()?; // Filter is removed - let expected = "Projection: test.a\ - \n EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a + EmptyRelation + ") } } From f7791be4a559ce8c26b9d099a0ce6941239de6b6 Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 16:59:33 -0400 Subject: [PATCH 06/13] migrate tests in `eliminate_group_by_constant.rs` to insta --- .../src/eliminate_group_by_constant.rs | 121 +++++++----------- 1 file changed, 47 insertions(+), 74 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 7e252d6dcea0e..bd5e6910201cc 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -115,6 +115,7 @@ fn is_constant_expression(expr: &Expr) -> bool { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; use arrow::datatypes::DataType; @@ -129,6 +130,20 @@ mod tests { use std::sync::Arc; + macro_rules! 
assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(EliminateGroupByConstant::new()); + assert_optimized_plan_eq_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; + } + #[derive(Debug)] struct ScalarUDFMock { signature: Signature, @@ -167,17 +182,11 @@ mod tests { .aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: test.a, UInt32(1), count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, UInt32(1), count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -187,17 +196,11 @@ mod tests { .aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: Utf8(\"test\"), UInt32(123), count(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r#" + Projection: Utf8("test"), UInt32(123), count(test.c) + Aggregate: groupBy=[[]], aggr=[[count(test.c)]] + TableScan: test + "#) } #[test] @@ -207,16 +210,10 @@ mod tests { .aggregate(vec![col("a"), col("b")], vec![count(col("c"))])? .build()?; - let expected = "\ - Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -226,16 +223,10 @@ mod tests { .aggregate(vec![lit(123u32)], Vec::::new())? 
.build()?; - let expected = "\ - Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[UInt32(123)]], aggr=[[]] + TableScan: test + ") } #[test] @@ -248,17 +239,11 @@ mod tests { )? .build()?; - let expected = "\ - Projection: UInt32(123) AS const, test.a, count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: UInt32(123) AS const, test.a, count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -273,17 +258,11 @@ mod tests { .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -298,15 +277,9 @@ mod tests { .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? 
.build()?; - let expected = "\ - Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } } From ea3005b05d8edbf5e4dfe00f16b7b75cfc68439c Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 17:09:41 -0400 Subject: [PATCH 07/13] migrate tests in `eliminate_join.rs` to use snapshot assertions --- datafusion/optimizer/src/eliminate_join.rs | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 789235595dabf..bac82a2ee1316 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -74,15 +74,25 @@ impl OptimizerRule for EliminateJoin { #[cfg(test)] mod tests { + use crate::assert_optimized_plan_eq_snapshot; use crate::eliminate_join::EliminateJoin; - use crate::test::*; use datafusion_common::Result; use datafusion_expr::JoinType::Inner; - use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan}; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @$expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(EliminateJoin::new()); + assert_optimized_plan_eq_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -95,7 +105,6 @@ mod tests { )? 
.build()?; - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } } From d1ea368b53b3eac518fe4895eb3e5dd954745832 Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 17:16:33 -0400 Subject: [PATCH 08/13] migrate tests in `eliminate_nested_union.rs` to use snapshot assertions --- .../optimizer/src/eliminate_nested_union.rs | 174 ++++++++++-------- 1 file changed, 94 insertions(+), 80 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 94da08243d78f..fe835afbaa542 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -116,7 +116,7 @@ mod tests { use super::*; use crate::analyzer::type_coercion::TypeCoercion; use crate::analyzer::Analyzer; - use crate::test::*; + use crate::assert_optimized_plan_eq_snapshot; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{col, logical_plan::table_scan}; @@ -129,15 +129,21 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) - .execute_and_check(plan, &options, |_, _| {})?; - assert_optimized_plan_eq( - Arc::new(EliminateNestedUnion::new()), - analyzed_plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let options = ConfigOptions::default(); + let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) + .execute_and_check($plan, &options, |_, _| {})?; + let rule: Arc = Arc::new(EliminateNestedUnion::new()); + assert_optimized_plan_eq_snapshot!( + rule, + analyzed_plan, + @ $expected, + ) + }}; } #[test] @@ -146,11 +152,11 @@ mod tests { let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?; - let expected = "\ - Union\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + TableScan: table + ") } #[test] @@ -162,11 +168,12 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + ") } #[test] @@ -180,13 +187,13 @@ mod tests { .union(plan_builder.build()?)? .build()?; - let expected = "\ - Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -200,14 +207,15 @@ mod tests { .union(plan_builder.build()?)? .build()?; - let expected = "Union\ - \n Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -222,14 +230,15 @@ mod tests { .union_distinct(plan_builder.build()?)? 
.build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -243,13 +252,14 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } // We don't need to use project_with_column_index in logical optimizer, @@ -273,13 +283,14 @@ mod tests { )? .build()?; - let expected = "Union\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + ") } #[test] @@ -301,14 +312,15 @@ mod tests { )? 
.build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + ") } #[test] @@ -348,13 +360,14 @@ mod tests { .union(table_3.build()?)? .build()?; - let expected = "Union\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + ") } #[test] @@ -394,13 +407,14 @@ mod tests { .union_distinct(table_3.build()?)? 
.build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + ") } } From d0c76dab1138074576f277ceb827dd88c38db94d Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 17:20:18 -0400 Subject: [PATCH 09/13] migrate tests in `eliminate_outer_join.rs` to use snapshot assertions --- .../optimizer/src/eliminate_outer_join.rs | 80 +++++++++++-------- 1 file changed, 48 insertions(+), 32 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 1ecb32ca2a435..704a9e7e53414 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -304,6 +304,7 @@ fn extract_non_nullable_columns( #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; use arrow::datatypes::DataType; use datafusion_expr::{ @@ -313,8 +314,18 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @$expected:literal $(,)? 
+ ) => {{ + let rule: Arc = Arc::new(EliminateOuterJoin::new()); + assert_optimized_plan_eq_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -332,12 +343,13 @@ mod tests { )? .filter(col("t2.b").is_null())? .build()?; - let expected = "\ - Filter: t2.b IS NULL\ - \n Left Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS NULL + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -355,12 +367,13 @@ mod tests { )? .filter(col("t2.b").is_not_null())? .build()?; - let expected = "\ - Filter: t2.b IS NOT NULL\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS NOT NULL + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -382,12 +395,13 @@ mod tests { col("t1.c").lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) OR t1.c < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -409,12 +423,13 @@ mod tests { col("t2.c").lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) AND t2.c < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -436,11 +451,12 @@ mod tests { try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)), ))? 
.build()?; - let expected = "\ - Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } } From 7ad8be4e4b62646e8cedf55b155362b9345ffdf0 Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 17:51:45 -0400 Subject: [PATCH 10/13] migrate tests in `filter_null_join_keys.rs` to use snapshot assertions --- .../optimizer/src/filter_null_join_keys.rs | 142 +++++++++++------- 1 file changed, 88 insertions(+), 54 deletions(-) diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 2e7a751ca4c57..314b439cb51ee 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -107,35 +107,49 @@ fn create_not_null_predicate(filters: Vec) -> Expr { #[cfg(test)] mod tests { use super::*; - use crate::test::assert_optimized_plan_eq; + use crate::assert_optimized_plan_eq_snapshot; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Column; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder}; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let rule: Arc = Arc::new(FilterNullJoinKeys {}); + assert_optimized_plan_eq_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] fn left_nullable() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?; - let expected = "Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] fn left_nullable_left_join() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?; - let expected = "Left Join: t1.optional_id = t2.id\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Left Join: t1.optional_id = t2.id + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -144,22 +158,26 @@ mod tests { // Note: order of tables is reversed let plan = build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?; - let expected = "Left Join: t2.id = t1.optional_id\ - \n TableScan: t2\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Left Join: t2.id = t1.optional_id + TableScan: t2 + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + ") } #[test] fn left_nullable_on_condition_reversed() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?; - let expected = "Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id = t2.id + Filter: 
t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -189,14 +207,16 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id\ - \n Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL\ - \n TableScan: t3\ - \n Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id + Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL + TableScan: t3 + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -213,11 +233,13 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1)\ - \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1) + Filter: t1.optional_id + UInt32(1) IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -234,11 +256,13 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1)\ - \n TableScan: t1\ - \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1) + TableScan: t1 + Filter: t2.optional_id + UInt32(1) IS NOT NULL + TableScan: t2 + ") } #[test] @@ -255,13 +279,14 @@ mod tests { None, )? 
.build()?; - let expected = - "Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1)\ - \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t1\ - \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1) + Filter: t1.optional_id + UInt32(1) IS NOT NULL + TableScan: t1 + Filter: t2.optional_id + UInt32(1) IS NOT NULL + TableScan: t2 + ") } #[test] @@ -283,13 +308,22 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.optional_id = t2.optional_id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n Filter: t2.optional_id IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan_from_cols, expected)?; - assert_optimized_plan_equal(plan_from_exprs, expected) + + assert_optimized_plan_equal!(plan_from_cols, @r" + Inner Join: t1.optional_id = t2.optional_id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + Filter: t2.optional_id IS NOT NULL + TableScan: t2 + ")?; + + assert_optimized_plan_equal!(plan_from_exprs, @r" + Inner Join: t1.optional_id = t2.optional_id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + Filter: t2.optional_id IS NOT NULL + TableScan: t2 + ") } fn build_plan( From 4d2b935d62193a55a2797af527a3134e763dd426 Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 17:57:28 -0400 Subject: [PATCH 11/13] fix type inference --- datafusion/optimizer/src/test/mod.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index cf501dce72417..4d387d5e4b3b0 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -195,7 +195,7 @@ macro_rules! 
assert_optimized_plan_eq_snapshot { let optimized_plan = optimizer.optimize($plan, &opt_context, |_, _| {})?; insta::assert_snapshot!(optimized_plan, @ $expected); - Ok(()) + Ok::<(), datafusion_common::DataFusionError>(()) }}; } @@ -229,6 +229,24 @@ pub fn assert_optimized_plan_with_rules( Ok(()) } +// #[macro_export] +// macro_rules! assert_optimized_plan_with_rules_snapshot { +// ( +// $rule:expr, +// $plan:expr, +// @ $expected:literal, +// $eq:expr $(,)? +// ) => {{ +// let optimized_plan = generate_optimized_plan_with_rules(rules, plan); +// if eq { +// assert_eq!(formatted_plan, expected); +// } else { +// assert_ne!(formatted_plan, expected); +// } +// Ok(()) +// }}; +// } + pub fn assert_optimized_plan_eq_display_indent( rule: Arc, plan: LogicalPlan, From 0f50f9a4ce646156db94596f4db84acce7c67425 Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 18:39:26 -0400 Subject: [PATCH 12/13] fix macro to use crate path for OptimizerContext and Optimizer --- datafusion/optimizer/src/test/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 4d387d5e4b3b0..3b50fda862d0e 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -189,9 +189,9 @@ macro_rules! assert_optimized_plan_eq_snapshot { @ $expected:literal $(,)? 
) => {{ // Apply the rule once - let opt_context = crate::OptimizerContext::new().with_max_passes(1); + let opt_context = $crate::OptimizerContext::new().with_max_passes(1); - let optimizer = crate::Optimizer::with_rules(vec![Arc::clone(&$rule)]); + let optimizer = $crate::Optimizer::with_rules(vec![Arc::clone(&$rule)]); let optimized_plan = optimizer.optimize($plan, &opt_context, |_, _| {})?; insta::assert_snapshot!(optimized_plan, @ $expected); From 0e264c1bb06bbc5909b06f09e729816a30aaaa8b Mon Sep 17 00:00:00 2001 From: qstommyshu Date: Mon, 28 Apr 2025 19:04:24 -0400 Subject: [PATCH 13/13] clean up --- datafusion/optimizer/src/test/mod.rs | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 3b50fda862d0e..5927ffea5ae24 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -229,24 +229,6 @@ pub fn assert_optimized_plan_with_rules( Ok(()) } -// #[macro_export] -// macro_rules! assert_optimized_plan_with_rules_snapshot { -// ( -// $rule:expr, -// $plan:expr, -// @ $expected:literal, -// $eq:expr $(,)? -// ) => {{ -// let optimized_plan = generate_optimized_plan_with_rules(rules, plan); -// if eq { -// assert_eq!(formatted_plan, expected); -// } else { -// assert_ne!(formatted_plan, expected); -// } -// Ok(()) -// }}; -// } - pub fn assert_optimized_plan_eq_display_indent( rule: Arc, plan: LogicalPlan,