From b08eef9742fd6859628598550837d8bcec31319b Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Mon, 26 Dec 2022 19:08:43 +0800 Subject: [PATCH 1/7] fix: do not create projection plan manually Signed-off-by: Ruihang Xia --- datafusion/core/tests/sql/predicates.rs | 4 +- datafusion/expr/src/expr_schema.rs | 21 +++-- datafusion/expr/src/logical_plan/plan.rs | 1 + .../optimizer/src/common_subexpr_eliminate.rs | 89 +++++++++++++++---- 4 files changed, 88 insertions(+), 27 deletions(-) diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index 94d3e06149f42..d56f95e5513d6 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -591,8 +591,8 @@ async fn multiple_or_predicates() -> Result<()> { " Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " Projection: lineitem.l_partkey, lineitem.l_quantity [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) [lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " Projection: lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(3000),15,2) AS lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(2000),15,2) AS lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity, lineitem.l_partkey, lineitem.l_quantity [lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity:Boolean;N, lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity) AND (lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity) [lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean, l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " Projection: lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity, lineitem.l_quantity <= Decimal128(Some(3000),15,2) AS lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity, lineitem.l_quantity >= Decimal128(Some(2000),15,2) AS lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity, lineitem.l_partkey, lineitem.l_quantity [lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity <= Decimal128(Some(3000),15,2)Decimal128(Some(3000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2)lineitem.l_quantity <= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2)lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantitylineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity:Boolean, lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity:Boolean, l_partkey:Int64, l_quantity:Decimal128(15, 2)]", " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2), lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) AS lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AS lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AS lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", " Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index c1a625cf47c22..8ccc884fc4aac 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -235,6 +235,7 @@ impl ExprSchemable for Expr { /// Returns a [arrow::datatypes::Field] compatible with this expression. fn to_field(&self, input_schema: &DFSchema) -> Result { + println!("to_field: {:?}", self); match self { Expr::Column(c) => Ok(DFField::new( c.relation.as_deref(), @@ -242,12 +243,20 @@ impl ExprSchemable for Expr { self.get_type(input_schema)?, self.nullable(input_schema)?, )), - _ => Ok(DFField::new( - None, - &self.display_name()?, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )), + _ => { + let name = &self.display_name()?; + println!("name: {:?}", name); + let data_type = self.get_type(input_schema)?; + println!("data type: {:?}", data_type); + let nullable = self.nullable(input_schema)?; + println!("nullable: {:?}", nullable); + Ok(DFField::new( + None, + &self.display_name()?, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + )) + } } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 23f26ad8d1f5e..6365022d918e3 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1256,6 +1256,7 @@ impl Projection { exprlist_to_fields(&expr, &input)?, input.schema().metadata().clone(), )?); + println!("generated schema: {schema:?}"); Self::try_new_with_schema(expr, input, schema) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index a8c9f5d867bac..59f00bd8187fe 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::DataType; -use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{DFField, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::{ col, expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, @@ -295,7 +295,6 @@ fn build_project_plan( expr_set: &ExprSet, ) -> Result { let mut project_exprs = vec![]; - let mut fields = vec![]; let mut fields_set = BTreeSet::new(); for id in affected_id { @@ -304,7 +303,6 @@ fn build_project_plan( // todo: check `nullable` let field = DFField::new(None, &id, data_type.clone(), true); fields_set.insert(field.name().to_owned()); - fields.push(field); project_exprs.push(expr.clone().alias(&id)); } _ => { @@ -317,17 +315,13 @@ fn build_project_plan( for field in input.schema().fields() { if fields_set.insert(field.qualified_name()) { - fields.push(field.clone()); project_exprs.push(Expr::Column(field.qualified_column())); } } - let schema = DFSchema::new_with_metadata(fields, HashMap::new())?; - - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( + Ok(LogicalPlan::Projection(Projection::try_new( project_exprs, Arc::new(input), - Arc::new(schema), )?)) } @@ -567,6 +561,7 @@ mod test { use arrow::datatypes::{Field, Schema}; + use datafusion_common::DFSchema; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, @@ -762,16 +757,28 @@ mod test { fn redundant_project_fields() { let table_scan = test_table_scan().unwrap(); let affected_id: BTreeSet = - ["c+a".to_string(), "d+a".to_string()].into_iter().collect(); - let expr_set = [ + ["c+a".to_string(), "b+a".to_string()].into_iter().collect(); + let expr_set_1 = [ + ( + "c+a".to_string(), + (col("c") + col("a"), 1, DataType::UInt32), + ), + ( + "b+a".to_string(), + (col("b") + col("a"), 1, DataType::UInt32), + ), + ] + .into_iter() + .collect(); + let expr_set_2 = [ ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)), - ("d+a".to_string(), (col("d+a"), 1, DataType::UInt32)), + ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)), ] .into_iter() .collect(); let project = - build_project_plan(table_scan, affected_id.clone(), &expr_set).unwrap(); - let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap(); + build_project_plan(table_scan, affected_id.clone(), &expr_set_1).unwrap(); + let project_2 = build_project_plan(project, affected_id, &expr_set_2).unwrap(); let mut field_set = BTreeSet::new(); for field in project_2.schema().fields() { @@ -789,15 +796,35 @@ mod test { .build() .unwrap(); let affected_id: BTreeSet = - ["c+a".to_string(), "d+a".to_string()].into_iter().collect(); - let expr_set = [ - ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)), - ("d+a".to_string(), (col("d+a"), 1, DataType::UInt32)), + ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()] + .into_iter() + .collect(); + let expr_set_1 = [ + ( + "test1.c+test1.a".to_string(), + (col("test1.c") + col("test1.a"), 1, DataType::UInt32), + ), + ( + "test1.b+test1.a".to_string(), + (col("test1.b") + col("test1.a"), 1, DataType::UInt32), + ), + ] + .into_iter() + .collect(); + let expr_set_2 = [ + ( + "test1.c+test1.a".to_string(), + (col("test1.c+test1.a"), 1, DataType::UInt32), + ), + ( + "test1.b+test1.a".to_string(), + (col("test1.b+test1.a"), 1, DataType::UInt32), + ), ] .into_iter() .collect(); - let project = build_project_plan(join, affected_id.clone(), &expr_set).unwrap(); - let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap(); + let project = build_project_plan(join, affected_id.clone(), &expr_set_1).unwrap(); + let project_2 = build_project_plan(project, affected_id, &expr_set_2).unwrap(); let mut field_set = BTreeSet::new(); for field in project_2.schema().fields() { @@ -858,4 +885,28 @@ mod test { ]"###; assert_eq!(expected, formatted_fields_with_datatype); } + + #[test] + fn cross_plans_subexpr_() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(binary_expr( + binary_expr(lit(1), Operator::Gt, col("a")), + Operator::And, + binary_expr(lit(1), Operator::Gt, col("a")), + ))? + .build()?; + + let expected = "Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ + \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ + \n TableScan: test"; + + let output_schema = plan.schema(); + println!("output schema: {:?}", output_schema); + + assert_optimized_plan_eq(expected, &plan); + + Ok(()) + } } From b4ac7c98dc28bec7ab276dd7e3fc7b94db2dba60 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Mon, 26 Dec 2022 19:28:03 +0800 Subject: [PATCH 2/7] add another projection to change schema back Signed-off-by: Ruihang Xia --- .../optimizer/src/common_subexpr_eliminate.rs | 46 +++++++++++++------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 59f00bd8187fe..d92a0da1704f2 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::DataType; -use datafusion_common::{DFField, DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::{ col, expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, @@ -131,7 +131,7 @@ impl OptimizerRule for CommonSubexprEliminate { predicate, &mut expr_set, &mut id_array, - input_schema, + input_schema.clone(), )?; let (mut new_expr, new_input) = self.rewrite_expr( @@ -143,10 +143,15 @@ impl OptimizerRule for CommonSubexprEliminate { )?; if let Some(predicate) = pop_expr(&mut new_expr)?.pop() { - Ok(Some(LogicalPlan::Filter(Filter::try_new( + // Ok(Some(LogicalPlan::Filter(Filter::try_new( + // predicate, + // Arc::new(new_input), + // )?))) + let filter = Arc::new(LogicalPlan::Filter(Filter::try_new( predicate, Arc::new(new_input), - )?))) + )?)); + Ok(Some(build_recover_project_plan(&input_schema, filter))) } else { Err(DataFusionError::Internal( "Failed to pop predicate expr".to_string(), @@ -209,16 +214,17 @@ impl OptimizerRule for CommonSubexprEliminate { } LogicalPlan::Sort(Sort { expr, input, fetch }) => { let input_schema = Arc::clone(input.schema()); - let arrays = to_arrays(expr, input_schema, &mut expr_set)?; + let arrays = to_arrays(expr, input_schema.clone(), &mut expr_set)?; let (mut new_expr, new_input) = self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?; - Ok(Some(LogicalPlan::Sort(Sort { + let sort = Arc::new(LogicalPlan::Sort(Sort { expr: pop_expr(&mut new_expr)?, input: Arc::new(new_input), fetch: *fetch, - }))) + })); + Ok(Some(build_recover_project_plan(&input_schema, sort))) } LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_) @@ -325,6 +331,18 @@ fn build_project_plan( )?)) } +fn build_recover_project_plan(schema: &DFSchema, input: Arc) -> LogicalPlan { + let col_exprs = schema + .fields() + .iter() + .map(|field| Expr::Column(field.qualified_column())) + .collect(); + LogicalPlan::Projection( + Projection::try_new(col_exprs, input) + .expect("Cannot build projection plan from an invalid schema"), + ) +} + /// Go through an expression tree and generate identifier. /// /// An identifier contains information of the expression itself and its sub-expression. @@ -866,10 +884,6 @@ mod test { .collect(); let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}"); let expected = r###"[ - ( - "CAST(table.a AS Int64)table.a", - Int64, - ), ( "a", UInt64, @@ -898,9 +912,13 @@ mod test { ))? .build()?; - let expected = "Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ - \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ - \n TableScan: test"; + // let expected = "Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ + // \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ + // \n TableScan: test"; + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ + \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c + \n TableScan: test"; let output_schema = plan.schema(); println!("output schema: {:?}", output_schema); From 8791f63dace11093a62f9c871d9c5be15c33e358 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Mon, 26 Dec 2022 19:51:41 +0800 Subject: [PATCH 3/7] conditional recover and add document Signed-off-by: Ruihang Xia --- .../optimizer/src/common_subexpr_eliminate.rs | 50 ++++++++++++------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d92a0da1704f2..4234b605c3358 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -86,7 +86,7 @@ impl CommonSubexprEliminate { .try_optimize(input, config)? .unwrap_or_else(|| input.clone()); if !affected_id.is_empty() { - new_input = build_project_plan(new_input, affected_id, expr_set)?; + new_input = build_common_expr_project_plan(new_input, affected_id, expr_set)?; } Ok((rewrite_exprs, new_input)) @@ -143,15 +143,15 @@ impl OptimizerRule for CommonSubexprEliminate { )?; if let Some(predicate) = pop_expr(&mut new_expr)?.pop() { - // Ok(Some(LogicalPlan::Filter(Filter::try_new( - // predicate, - // Arc::new(new_input), - // )?))) - let filter = Arc::new(LogicalPlan::Filter(Filter::try_new( + let filter = LogicalPlan::Filter(Filter::try_new( predicate, Arc::new(new_input), - )?)); - Ok(Some(build_recover_project_plan(&input_schema, filter))) + )?); + if filter.schema() == &input_schema { + Ok(Some(filter)) + } else { + Ok(Some(build_recover_project_plan(&input_schema, filter))) + } } else { Err(DataFusionError::Internal( "Failed to pop predicate expr".to_string(), @@ -219,12 +219,16 @@ impl OptimizerRule for CommonSubexprEliminate { let (mut new_expr, new_input) = self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?; - let sort = Arc::new(LogicalPlan::Sort(Sort { + let sort = LogicalPlan::Sort(Sort { expr: pop_expr(&mut new_expr)?, input: Arc::new(new_input), fetch: *fetch, - })); - Ok(Some(build_recover_project_plan(&input_schema, sort))) + }); + if sort.schema() == &input_schema { + Ok(Some(sort)) + } else { + Ok(Some(build_recover_project_plan(&input_schema, sort))) + } } LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_) @@ -295,7 +299,7 @@ fn to_arrays( } /// Build the "intermediate" projection plan that evaluates the extracted common expressions. -fn build_project_plan( +fn build_common_expr_project_plan( input: LogicalPlan, affected_id: BTreeSet, expr_set: &ExprSet, @@ -331,14 +335,18 @@ fn build_project_plan( )?)) } -fn build_recover_project_plan(schema: &DFSchema, input: Arc) -> LogicalPlan { +/// Build the projection plan to eliminate unexpected columns produced by +/// the "intermediate" projection plan built in [build_common_expr_project_plan]. +/// +/// This is for those plans who don't keep its own output schema like `Filter` or `Sort`. +fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalPlan { let col_exprs = schema .fields() .iter() .map(|field| Expr::Column(field.qualified_column())) .collect(); LogicalPlan::Projection( - Projection::try_new(col_exprs, input) + Projection::try_new(col_exprs, Arc::new(input)) .expect("Cannot build projection plan from an invalid schema"), ) } @@ -767,7 +775,6 @@ mod test { \n TableScan: test"; assert_optimized_plan_eq(expected, &plan); - Ok(()) } @@ -795,8 +802,10 @@ mod test { .into_iter() .collect(); let project = - build_project_plan(table_scan, affected_id.clone(), &expr_set_1).unwrap(); - let project_2 = build_project_plan(project, affected_id, &expr_set_2).unwrap(); + build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1) + .unwrap(); + let project_2 = + build_common_expr_project_plan(project, affected_id, &expr_set_2).unwrap(); let mut field_set = BTreeSet::new(); for field in project_2.schema().fields() { @@ -841,8 +850,11 @@ mod test { ] .into_iter() .collect(); - let project = build_project_plan(join, affected_id.clone(), &expr_set_1).unwrap(); - let project_2 = build_project_plan(project, affected_id, &expr_set_2).unwrap(); + let project = + build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1) + .unwrap(); + let project_2 = + build_common_expr_project_plan(project, affected_id, &expr_set_2).unwrap(); let mut field_set = BTreeSet::new(); for field in project_2.schema().fields() { From cdcb5e8160a3999d54809cb90c2fef5754c0e0bd Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Mon, 26 Dec 2022 19:55:29 +0800 Subject: [PATCH 4/7] clean up Signed-off-by: Ruihang Xia --- datafusion/expr/src/expr_schema.rs | 21 ++++++------------- datafusion/expr/src/logical_plan/plan.rs | 1 - .../optimizer/src/common_subexpr_eliminate.rs | 10 ++------- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 8ccc884fc4aac..c1a625cf47c22 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -235,7 +235,6 @@ impl ExprSchemable for Expr { /// Returns a [arrow::datatypes::Field] compatible with this expression. fn to_field(&self, input_schema: &DFSchema) -> Result { - println!("to_field: {:?}", self); match self { Expr::Column(c) => Ok(DFField::new( c.relation.as_deref(), @@ -243,20 +242,12 @@ impl ExprSchemable for Expr { self.get_type(input_schema)?, self.nullable(input_schema)?, )), - _ => { - let name = &self.display_name()?; - println!("name: {:?}", name); - let data_type = self.get_type(input_schema)?; - println!("data type: {:?}", data_type); - let nullable = self.nullable(input_schema)?; - println!("nullable: {:?}", nullable); - Ok(DFField::new( - None, - &self.display_name()?, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )) - } + _ => Ok(DFField::new( + None, + &self.display_name()?, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + )), } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6365022d918e3..23f26ad8d1f5e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1256,7 +1256,6 @@ impl Projection { exprlist_to_fields(&expr, &input)?, input.schema().metadata().clone(), )?); - println!("generated schema: {schema:?}"); Self::try_new_with_schema(expr, input, schema) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 4234b605c3358..887d2c6d409d1 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -913,7 +913,7 @@ mod test { } #[test] - fn cross_plans_subexpr_() -> Result<()> { + fn filter_schema_changed() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) @@ -924,17 +924,11 @@ mod test { ))? .build()?; - // let expected = "Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ - // \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ - // \n TableScan: test"; let expected = "Projection: test.a, test.b, test.c\ \n Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ - \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c + \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ \n TableScan: test"; - let output_schema = plan.schema(); - println!("output schema: {:?}", output_schema); - assert_optimized_plan_eq(expected, &plan); Ok(()) From 162838c2b34278a0c4f63597b8713ab5e94249e7 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 27 Dec 2022 11:10:46 +0800 Subject: [PATCH 5/7] check schema after all Signed-off-by: Ruihang Xia --- .../optimizer/src/common_subexpr_eliminate.rs | 69 ++++++++----------- 1 file changed, 30 insertions(+), 39 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 887d2c6d409d1..4ad2cb84ee308 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -101,7 +101,8 @@ impl OptimizerRule for CommonSubexprEliminate { ) -> Result> { let mut expr_set = ExprSet::new(); - match plan { + let original_schema = plan.schema().clone(); + let mut optimized_plan = match plan { LogicalPlan::Projection(Projection { expr, input, @@ -114,13 +115,11 @@ impl OptimizerRule for CommonSubexprEliminate { let (mut new_expr, new_input) = self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?; - Ok(Some(LogicalPlan::Projection( - Projection::try_new_with_schema( - pop_expr(&mut new_expr)?, - Arc::new(new_input), - schema.clone(), - )?, - ))) + LogicalPlan::Projection(Projection::try_new_with_schema( + pop_expr(&mut new_expr)?, + Arc::new(new_input), + schema.clone(), + )?) } LogicalPlan::Filter(filter) => { let input = &filter.input; @@ -131,7 +130,7 @@ impl OptimizerRule for CommonSubexprEliminate { predicate, &mut expr_set, &mut id_array, - input_schema.clone(), + input_schema, )?; let (mut new_expr, new_input) = self.rewrite_expr( @@ -143,19 +142,11 @@ impl OptimizerRule for CommonSubexprEliminate { )?; if let Some(predicate) = pop_expr(&mut new_expr)?.pop() { - let filter = LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_input), - )?); - if filter.schema() == &input_schema { - Ok(Some(filter)) - } else { - Ok(Some(build_recover_project_plan(&input_schema, filter))) - } + LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(new_input))?) } else { - Err(DataFusionError::Internal( + return Err(DataFusionError::Internal( "Failed to pop predicate expr".to_string(), - )) + )); } } LogicalPlan::Window(Window { @@ -174,11 +165,11 @@ impl OptimizerRule for CommonSubexprEliminate { config, )?; - Ok(Some(LogicalPlan::Window(Window { + LogicalPlan::Window(Window { input: Arc::new(new_input), window_expr: pop_expr(&mut new_expr)?, schema: schema.clone(), - }))) + }) } LogicalPlan::Aggregate(Aggregate { group_expr, @@ -203,32 +194,25 @@ impl OptimizerRule for CommonSubexprEliminate { let new_aggr_expr = pop_expr(&mut new_expr)?; let new_group_expr = pop_expr(&mut new_expr)?; - Ok(Some(LogicalPlan::Aggregate( - Aggregate::try_new_with_schema( - Arc::new(new_input), - new_group_expr, - new_aggr_expr, - schema.clone(), - )?, - ))) + LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + Arc::new(new_input), + new_group_expr, + new_aggr_expr, + schema.clone(), + )?) } LogicalPlan::Sort(Sort { expr, input, fetch }) => { let input_schema = Arc::clone(input.schema()); - let arrays = to_arrays(expr, input_schema.clone(), &mut expr_set)?; + let arrays = to_arrays(expr, input_schema, &mut expr_set)?; let (mut new_expr, new_input) = self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?; - let sort = LogicalPlan::Sort(Sort { + LogicalPlan::Sort(Sort { expr: pop_expr(&mut new_expr)?, input: Arc::new(new_input), fetch: *fetch, - }); - if sort.schema() == &input_schema { - Ok(Some(sort)) - } else { - Ok(Some(build_recover_project_plan(&input_schema, sort))) - } + }) } LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_) @@ -254,9 +238,16 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Extension(_) | LogicalPlan::Prepare(_) => { // apply the optimization to all inputs of the plan - Ok(Some(utils::optimize_children(self, plan, config)?)) + utils::optimize_children(self, plan, config)? } + }; + + // add an additional projection if the output schema changed. + if optimized_plan.schema() != &original_schema { + optimized_plan = build_recover_project_plan(&original_schema, optimized_plan); } + + Ok(Some(optimized_plan)) } fn name(&self) -> &str { From 1c225836dd39de44405816c5c945616405c56003 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 27 Dec 2022 11:12:06 +0800 Subject: [PATCH 6/7] Update datafusion/optimizer/src/common_subexpr_eliminate.rs Co-authored-by: Andrew Lamb --- datafusion/optimizer/src/common_subexpr_eliminate.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 4ad2cb84ee308..c5424dbfd17ba 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -908,10 +908,8 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(binary_expr( - binary_expr(lit(1), Operator::Gt, col("a")), - Operator::And, - binary_expr(lit(1), Operator::Gt, col("a")), + .filter(lit(1).gt(col("a"))).and( + lit(1).gt(col("a"))), ))? .build()?; From 7f51892952300e702d672a83dd3674122dc69edb Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 27 Dec 2022 11:13:55 +0800 Subject: [PATCH 7/7] fix format Signed-off-by: Ruihang Xia --- datafusion/optimizer/src/common_subexpr_eliminate.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index c5424dbfd17ba..c8bddcfbfacd0 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -908,9 +908,7 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(lit(1).gt(col("a"))).and( - lit(1).gt(col("a"))), - ))? + .filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))? .build()?; let expected = "Projection: test.a, test.b, test.c\