diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java index 3590d07483f0b9..d438db12ea437a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java @@ -167,18 +167,26 @@ public List buildRules() { LogicalAggregate agg = ctx.root; List output = bind(agg.getOutputExpressions(), agg.children(), agg, ctx.cascadesContext); + + // The columns referenced in group by are first obtained from the child's output, + // and then from the node's output + Map childOutputsToExpr = agg.child().getOutput().stream() + .collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr)); Map aliasNameToExpr = output.stream() .filter(ne -> ne instanceof Alias) .map(Alias.class::cast) .collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr)); + aliasNameToExpr.entrySet().stream() + .forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue())); + List replacedGroupBy = agg.getGroupByExpressions().stream() .map(groupBy -> { if (groupBy instanceof UnboundSlot) { UnboundSlot unboundSlot = (UnboundSlot) groupBy; if (unboundSlot.getNameParts().size() == 1) { String name = unboundSlot.getNameParts().get(0); - if (aliasNameToExpr.containsKey(name)) { - return aliasNameToExpr.get(name); + if (childOutputsToExpr.containsKey(name)) { + return childOutputsToExpr.get(name); } } } @@ -197,10 +205,17 @@ public List buildRules() { List output = bind(repeat.getOutputExpressions(), repeat.children(), repeat, ctx.cascadesContext); + // The columns referenced in group by are first obtained from the child's output, + // and then from the node's output + Map childOutputsToExpr = repeat.child().getOutput().stream() + .collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr)); Map aliasNameToExpr = output.stream() .filter(ne -> ne instanceof Alias) .map(Alias.class::cast) .collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr)); + aliasNameToExpr.entrySet().stream() + .forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue())); + List> replacedGroupingSets = repeat.getGroupingSets().stream() .map(groupBy -> groupBy.stream().map(expr -> { @@ -208,8 +223,8 @@ public List buildRules() { UnboundSlot unboundSlot = (UnboundSlot) expr; if (unboundSlot.getNameParts().size() == 1) { String name = unboundSlot.getNameParts().get(0); - if (aliasNameToExpr.containsKey(name)) { - return aliasNameToExpr.get(name); + if (childOutputsToExpr.containsKey(name)) { + return childOutputsToExpr.get(name); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 8c62aab3d4a300..7818fe2bb78208 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -200,7 +200,7 @@ private Set collectNeedToSlotExpressions(LogicalRepeat repeat) } private Plan pushDownProject(Set pushedExprs, Plan originBottomPlan) { - if (!pushedExprs.equals(originBottomPlan.getOutputSet())) { + if (!pushedExprs.equals(originBottomPlan.getOutputSet()) && !pushedExprs.isEmpty()) { return new LogicalProject<>(ImmutableList.copyOf(pushedExprs), originBottomPlan); } return originBottomPlan; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index d69cff4d58e18c..b5a1ddc32b7594 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -93,7 +93,8 @@ public Rule build() { aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions); List normalizedGroupBy = - (List) groupByAndArgumentToSlotContext.normalizeToUseSlotRef(aggregate.getGroupByExpressions()); + (List) groupByAndArgumentToSlotContext + .normalizeToUseSlotRef(aggregate.getGroupByExpressions()); // we can safely add all groupBy and aggregate functions to output, because we will // add a project on it, and the upper project can protect the scope of visible of slot diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java index cec372311eb773..a30364b1162fd6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java @@ -183,4 +183,53 @@ public void test() { PlanChecker.from(connectContext) .checkPlannerResult("select if(k1 = 1, 2, k1) k_if from t1"); } + + @Test + public void test1() { + PlanChecker.from(connectContext) + .checkPlannerResult("select coalesce(col1, 'all') as col1, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test1_1() { + PlanChecker.from(connectContext) + .checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test1_2() { + PlanChecker.from(connectContext) + .checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());"); + } + + @Test + public void test2() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(1 = null, 'all', 2) as col1, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test2_1() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(col1 = null, 'all', 2) as col1, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test2_2() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test2_3() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());"); + } } diff --git a/regression-test/data/nereids_syntax_p0/grouping_sets.out b/regression-test/data/nereids_syntax_p0/grouping_sets.out index 38b19938d03389..6c18f54e2b5198 100644 --- a/regression-test/data/nereids_syntax_p0/grouping_sets.out +++ b/regression-test/data/nereids_syntax_p0/grouping_sets.out @@ -224,3 +224,37 @@ 3 4 +-- !select1 -- +a 1 +all 1 +all 2 + +-- !select2 -- +a 1 +all 1 +all 2 + +-- !select3 -- +\N 2 +a 1 +all 1 + +-- !select4 -- +2 1 +2 1 +2 2 + +-- !select5 -- +2 1 +2 1 +2 2 + +-- !select6 -- +2 1 +2 1 +2 2 + +-- !select7 -- +2 1 +2 1 +2 2 diff --git a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy index 5218a4215c181f..8e6cc6e5c74fe3 100644 --- a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy +++ b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy @@ -232,4 +232,32 @@ suite("test_nereids_grouping_sets") { ) T ) T2; """ + + order_qt_select1 """ + select coalesce(col1, 'all') as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select2 """ + select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select3 """ + select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),()); + """ + + order_qt_select4 """ + select if(1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select5 """ + select if(col1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select6 """ + select if(1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select7 """ + select if(col1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ }