diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java index 7daeef2401988a..4fc02581ca4c77 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; @@ -42,6 +43,7 @@ import org.apache.doris.nereids.trees.plans.logical.OutputPrunable; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; import org.apache.doris.qe.ConnectContext; @@ -345,6 +347,8 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) } List prunedOutputs = Lists.newArrayList(); List> constantExprsList = union.getConstantExprsList(); + List> regularChildrenOutputs = union.getRegularChildrenOutputs(); + List children = union.children(); List extractColumnIndex = Lists.newArrayList(); for (int i = 0; i < originOutput.size(); i++) { NamedExpression output = originOutput.get(i); @@ -353,31 +357,41 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) extractColumnIndex.add(i); } } - if (prunedOutputs.isEmpty()) { - List candidates = Lists.newArrayList(originOutput); - candidates.retainAll(keys); - if (candidates.isEmpty()) { - candidates = originOutput; - } - NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(candidates); - prunedOutputs = ImmutableList.of(minimumColumn); - extractColumnIndex.add(originOutput.indexOf(minimumColumn)); - } - int len = extractColumnIndex.size(); ImmutableList.Builder> prunedConstantExprsList = ImmutableList.builderWithExpectedSize(constantExprsList.size()); - for (List row : constantExprsList) { - ImmutableList.Builder newRow = ImmutableList.builderWithExpectedSize(len); - for (int idx : extractColumnIndex) { - newRow.add(row.get(idx)); + if (prunedOutputs.isEmpty()) { + // process prune all columns + NamedExpression originSlot = originOutput.get(0); + prunedOutputs = ImmutableList.of(new SlotReference(originSlot.getExprId(), originSlot.getName(), + TinyIntType.INSTANCE, false, originSlot.getQualifier())); + regularChildrenOutputs = Lists.newArrayListWithCapacity(regularChildrenOutputs.size()); + children = Lists.newArrayListWithCapacity(children.size()); + for (int i = 0; i < union.getArity(); i++) { + LogicalProject project = new LogicalProject<>( + ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))), union.child(i)); + regularChildrenOutputs.add((List) project.getOutput()); + children.add(project); + } + for (int i = 0; i < constantExprsList.size(); i++) { + prunedConstantExprsList.add(ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1)))); + } + } else { + int len = extractColumnIndex.size(); + for (List row : constantExprsList) { + ImmutableList.Builder newRow = ImmutableList.builderWithExpectedSize(len); + for (int idx : extractColumnIndex) { + newRow.add(row.get(idx)); + } + prunedConstantExprsList.add(newRow.build()); } - prunedConstantExprsList.add(newRow.build()); } - if (prunedOutputs.equals(originOutput)) { + + if (prunedOutputs.equals(originOutput) && !context.requiredSlots.isEmpty()) { return union; } else { - return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList.build()); + return union.withNewOutputsChildrenAndConstExprsList(prunedOutputs, children, + regularChildrenOutputs, prunedConstantExprsList.build()); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java index 9f18eeb851fee6..1ae7219fbb7eb6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java @@ -19,6 +19,7 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.types.DoubleType; @@ -313,6 +314,21 @@ public void pruneAggregateOutput() { ); } + @Test + public void pruneUnionAllWithCount() { + PlanChecker.from(connectContext) + .analyze("select count() from (select 1, 2 union all select id, age from student) t") + .customRewrite(new ColumnPruning()) + .matches( + logicalProject( + logicalUnion( + logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral), + logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral) + ) + ).when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral) + ); + } + private List getOutputQualifiedNames(LogicalProject p) { return getOutputQualifiedNames(p.getOutputs()); } diff --git a/regression-test/suites/nereids_rules_p0/column_pruning/union_const_expr_column_pruning.groovy b/regression-test/suites/nereids_rules_p0/column_pruning/union_const_expr_column_pruning.groovy index 77d62a2189960b..cca668947e0d76 100644 --- a/regression-test/suites/nereids_rules_p0/column_pruning/union_const_expr_column_pruning.groovy +++ b/regression-test/suites/nereids_rules_p0/column_pruning/union_const_expr_column_pruning.groovy @@ -18,6 +18,7 @@ suite("const_expr_column_pruning") { sql """SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'""" // should only keep one column in union - sql "select count(1) from(select 3, 6 union all select 1, 3) t" - sql "select count(a) from(select 3 a, 6 union all select 1, 3) t" -} \ No newline at end of file + sql """select count(1) from(select 3, 6 union all select 1, 3) t""" + sql """select count(1) from(select 3, 6 union all select "1", 3) t""" + sql """select count(a) from(select 3 a, 6 union all select "1", 3) t""" +} diff --git a/regression-test/suites/nereids_rules_p0/column_pruning/window_column_pruning.groovy b/regression-test/suites/nereids_rules_p0/column_pruning/window_column_pruning.groovy index a83f8ed75280e5..dfe7a78f1659f3 100644 --- a/regression-test/suites/nereids_rules_p0/column_pruning/window_column_pruning.groovy +++ b/regression-test/suites/nereids_rules_p0/column_pruning/window_column_pruning.groovy @@ -56,5 +56,10 @@ suite("window_column_pruning") { sql "select id from (select id, rank() over() px from window_column_pruning union all select id, rank() over() px from window_column_pruning) a" notContains "rank" } + + explain { + sql "select count() from (select row_number() over(partition by id) from window_column_pruning) tmp" + notContains "row_number" + } }