From e6afa52849f80f972fdd2f45dfb058754d548a93 Mon Sep 17 00:00:00 2001 From: jianghaochen Date: Fri, 23 Dec 2022 10:48:39 +0800 Subject: [PATCH 1/2] [Fix](Nereids)fix scalarFunction and groupingSets --- .../rules/analysis/NormalizeRepeat.java | 14 +++-- .../rewrite/logical/NormalizeAggregate.java | 15 ++--- .../rewrite/logical/NormalizeToSlot.java | 60 ++++++++++++++++--- .../nereids/trees/plans/GroupingSetsTest.java | 49 +++++++++++++++ .../data/nereids_syntax_p0/grouping_sets.out | 34 +++++++++++ .../nereids_syntax_p0/grouping_sets.groovy | 28 +++++++++ 6 files changed, 180 insertions(+), 20 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 8c62aab3d4a300..2d7d3394ecf26e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -119,13 +119,14 @@ private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { // normalize grouping sets to List> List> normalizedGroupingSets = repeat.getGroupingSets() .stream() - .map(groupingSet -> (List) (List) context.normalizeToUseSlotRef(groupingSet)) + .map(groupingSet -> (List) (List) context.normalizeToUseSlotRef( + groupingSet, false)) .collect(ImmutableList.toImmutableList()); // replace the arguments of grouping scalar function to virtual slots // replace some complex expression to slot, e.g. `a + 1` List normalizedAggOutput = context.normalizeToUseSlotRef( - repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction); + repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction, true); Set virtualSlotsInFunction = ExpressionUtils.collect(normalizedAggOutput, VirtualSlotReference.class::isInstance); @@ -152,7 +153,7 @@ private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { .addAll(allVirtualSlots) .build(); - Set pushedProject = context.pushDownToNamedExpression(needToSlots); + Set pushedProject = context.pushDownToNamedExpression(needToSlots, repeat); Plan normalizedChild = pushDownProject(pushedProject, repeat.child()); LogicalRepeat normalizedRepeat = repeat.withNormalizedExpr( @@ -200,7 +201,7 @@ private Set collectNeedToSlotExpressions(LogicalRepeat repeat) } private Plan pushDownProject(Set pushedExprs, Plan originBottomPlan) { - if (!pushedExprs.equals(originBottomPlan.getOutputSet())) { + if (!pushedExprs.equals(originBottomPlan.getOutputSet()) && !pushedExprs.isEmpty()) { return new LogicalProject<>(ImmutableList.copyOf(pushedExprs), originBottomPlan); } return originBottomPlan; @@ -241,7 +242,7 @@ public NormalizeToSlotContext buildContext(Repeat repeat, normalizeToSlotMap.put(expression, pushDownTriplet.get()); } } - return new NormalizeToSlotContext(normalizeToSlotMap); + return new NormalizeToSlotContext(normalizeToSlotMap, repeat); } private Optional toGroupingSetExpressionPushDownTriplet( @@ -255,7 +256,8 @@ private Optional toGroupingSetExpressionPushDownTriplet( private Expression normalizeGroupingScalarFunction(NormalizeToSlotContext context, Expression expr) { if (expr instanceof GroupingScalarFunction) { GroupingScalarFunction function = (GroupingScalarFunction) expr; - List normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); + List normalizedRealExpressions = context.normalizeToUseSlotRef( + function.getArguments(), false); function = function.withChildren(normalizedRealExpressions); // eliminate GroupingScalarFunction and replace to VirtualSlotReference return Repeat.generateVirtualSlotByFunction(function); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index d69cff4d58e18c..85a5bdbd499269 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -65,9 +65,9 @@ public Rule build() { aggregate.getOutputExpressions(), Alias.class::isInstance); Set needToSlots = collectGroupByAndArgumentsOfAggregateFunctions(aggregate); NormalizeToSlotContext groupByAndArgumentToSlotContext = - NormalizeToSlotContext.buildContext(existsAliases, needToSlots); + NormalizeToSlotContext.buildContext(existsAliases, needToSlots, aggregate); Set bottomProjects = - groupByAndArgumentToSlotContext.pushDownToNamedExpression(needToSlots); + groupByAndArgumentToSlotContext.pushDownToNamedExpression(needToSlots, aggregate); Plan normalizedChild = bottomProjects.isEmpty() ? aggregate.child() : new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child()); @@ -78,7 +78,7 @@ public Rule build() { // some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace // the sum(value) to slot and move the `slot + 1` to the upper project later. List normalizeOutputPhase1 = groupByAndArgumentToSlotContext - .normalizeToUseSlotRef(aggregate.getOutputExpressions()); + .normalizeToUseSlotRef(aggregate.getOutputExpressions(), true); Set normalizedAggregateFunctions = ExpressionUtils.collect(normalizeOutputPhase1, AggregateFunction.class::isInstance); @@ -87,13 +87,14 @@ public Rule build() { // now reuse the exists alias for the aggregate functions, // or create new alias for the aggregate functions NormalizeToSlotContext aggregateFunctionToSlotContext = - NormalizeToSlotContext.buildContext(existsAliases, normalizedAggregateFunctions); + NormalizeToSlotContext.buildContext(existsAliases, normalizedAggregateFunctions, aggregate); Set normalizedAggregateFunctionsWithAlias = - aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions); + aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions, aggregate); List normalizedGroupBy = - (List) groupByAndArgumentToSlotContext.normalizeToUseSlotRef(aggregate.getGroupByExpressions()); + (List) groupByAndArgumentToSlotContext + .normalizeToUseSlotRef(aggregate.getGroupByExpressions(), false); // we can safely add all groupBy and aggregate functions to output, because we will // add a project on it, and the upper project can protect the scope of visible of slot @@ -107,7 +108,7 @@ public Rule build() { // replace aggregate function to slot List upperProjects = - aggregateFunctionToSlotContext.normalizeToUseSlotRef(normalizeOutputPhase1); + aggregateFunctionToSlotContext.normalizeToUseSlotRef(normalizeOutputPhase1, true); return new LogicalProject<>(upperProjects, normalizedAggregate); }).toRule(RuleType.NORMALIZE_AGGREGATE); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java index 34686534e39548..d574c5953ca678 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java @@ -21,6 +21,8 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction; +import org.apache.doris.nereids.trees.plans.Plan; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -29,6 +31,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import javax.annotation.Nullable; @@ -39,14 +42,17 @@ public interface NormalizeToSlot { /** NormalizeSlotContext */ class NormalizeToSlotContext { private final Map normalizeToSlotMap; + private final Plan currentPlan; - public NormalizeToSlotContext(Map normalizeToSlotMap) { + public NormalizeToSlotContext( + Map normalizeToSlotMap, Plan currentPlan) { this.normalizeToSlotMap = normalizeToSlotMap; + this.currentPlan = currentPlan; } /** buildContext */ public static NormalizeToSlotContext buildContext( - Set existsAliases, Set sourceExpressions) { + Set existsAliases, Set sourceExpressions, Plan currentPlan) { Map normalizeToSlotMap = Maps.newLinkedHashMap(); Map existsAliasMap = Maps.newLinkedHashMap(); @@ -62,36 +68,67 @@ public static NormalizeToSlotContext buildContext( NormalizeToSlotTriplet.toTriplet(expression, existsAliasMap.get(expression)); normalizeToSlotMap.put(expression, normalizeToSlotTriplet); } - return new NormalizeToSlotContext(normalizeToSlotMap); + return new NormalizeToSlotContext(normalizeToSlotMap, currentPlan); } /** normalizeToUseSlotRef, no custom normalize */ - public List normalizeToUseSlotRef(List expressions) { - return normalizeToUseSlotRef(expressions, (context, expr) -> expr); + public List normalizeToUseSlotRef( + List expressions, boolean isOutput) { + return normalizeToUseSlotRef(expressions, (context, expr) -> expr, isOutput); } /** normalizeToUseSlotRef */ public List normalizeToUseSlotRef(List expressions, - BiFunction customNormalize) { + BiFunction customNormalize, + boolean isOutput) { return expressions.stream() .map(expr -> (E) expr.rewriteDownShortCircuit(child -> { Expression newChild = customNormalize.apply(this, child); if (newChild != null && newChild != child) { return newChild; } + if (child instanceof ScalarFunction && !isOutput) { + return getSlotFromChildOutputsWhichEqualName(currentPlan, child); + } + if (child instanceof ScalarFunction && isOutput + && collectSlotFromChildOutputsWhichEqualName(currentPlan, child).isPresent()) { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); + return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.originExpr; + } NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; })).collect(ImmutableList.toImmutableList()); } + private Expression getSlotFromChildOutputsWhichEqualName(Plan repeat, Expression child) { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); + Optional slot = collectSlotFromChildOutputsWhichEqualName(repeat, child); + if (slot.isPresent()) { + return slot.get(); + } + return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; + } + + private Optional collectSlotFromChildOutputsWhichEqualName( + Plan repeat, Expression child) { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); + String asName = normalizeToSlotTriplet == null + ? child.toSql() : normalizeToSlotTriplet.remainExpr.getName(); + return repeat.child(0).getOutput().stream() + .filter(s -> s.getName().equals(asName)) + .findFirst(); + } + /** * generate bottom projections with groupByExpressions. * eg: * groupByExpressions: k1#0, k2#1 + 1; * bottom: k1#0, (k2#1 + 1) AS (k2 + 1)#2; */ - public Set pushDownToNamedExpression(Collection needToPushExpressions) { + public Set pushDownToNamedExpression( + Collection needToPushExpressions, Plan current) { return needToPushExpressions.stream() + .filter(e -> filterScalarFunWithSameAliasNameInChildOutput(e, current)) .map(expr -> { NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr); return normalizeToSlotTriplet == null @@ -99,6 +136,15 @@ public Set pushDownToNamedExpression(Collection slot = collectSlotFromChildOutputsWhichEqualName(current, expression); + return !slot.isPresent(); + } + return true; + } } /** NormalizeToSlotTriplet */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java index cec372311eb773..a30364b1162fd6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java @@ -183,4 +183,53 @@ public void test() { PlanChecker.from(connectContext) .checkPlannerResult("select if(k1 = 1, 2, k1) k_if from t1"); } + + @Test + public void test1() { + PlanChecker.from(connectContext) + .checkPlannerResult("select coalesce(col1, 'all') as col1, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test1_1() { + PlanChecker.from(connectContext) + .checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test1_2() { + PlanChecker.from(connectContext) + .checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());"); + } + + @Test + public void test2() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(1 = null, 'all', 2) as col1, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test2_1() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(col1 = null, 'all', 2) as col1, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test2_2() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test2_3() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());"); + } } diff --git a/regression-test/data/nereids_syntax_p0/grouping_sets.out b/regression-test/data/nereids_syntax_p0/grouping_sets.out index 38b19938d03389..6c18f54e2b5198 100644 --- a/regression-test/data/nereids_syntax_p0/grouping_sets.out +++ b/regression-test/data/nereids_syntax_p0/grouping_sets.out @@ -224,3 +224,37 @@ 3 4 +-- !select1 -- +a 1 +all 1 +all 2 + +-- !select2 -- +a 1 +all 1 +all 2 + +-- !select3 -- +\N 2 +a 1 +all 1 + +-- !select4 -- +2 1 +2 1 +2 2 + +-- !select5 -- +2 1 +2 1 +2 2 + +-- !select6 -- +2 1 +2 1 +2 2 + +-- !select7 -- +2 1 +2 1 +2 2 diff --git a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy index 5218a4215c181f..8e6cc6e5c74fe3 100644 --- a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy +++ b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy @@ -232,4 +232,32 @@ suite("test_nereids_grouping_sets") { ) T ) T2; """ + + order_qt_select1 """ + select coalesce(col1, 'all') as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select2 """ + select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select3 """ + select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),()); + """ + + order_qt_select4 """ + select if(1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select5 """ + select if(col1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select6 """ + select if(1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select7 """ + select if(col1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ } From 461f4d4c0635800404d07a1403cecc0301b313c4 Mon Sep 17 00:00:00 2001 From: jianghaochen Date: Tue, 27 Dec 2022 13:57:58 +0800 Subject: [PATCH 2/2] fix --- .../rules/analysis/BindSlotReference.java | 23 +++++-- .../rules/analysis/NormalizeRepeat.java | 12 ++-- .../rewrite/logical/NormalizeAggregate.java | 14 ++--- .../rewrite/logical/NormalizeToSlot.java | 60 +++---------------- 4 files changed, 38 insertions(+), 71 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java index 3590d07483f0b9..d438db12ea437a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java @@ -167,18 +167,26 @@ public List buildRules() { LogicalAggregate agg = ctx.root; List output = bind(agg.getOutputExpressions(), agg.children(), agg, ctx.cascadesContext); + + // The columns referenced in group by are first obtained from the child's output, + // and then from the node's output + Map childOutputsToExpr = agg.child().getOutput().stream() + .collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr)); Map aliasNameToExpr = output.stream() .filter(ne -> ne instanceof Alias) .map(Alias.class::cast) .collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr)); + aliasNameToExpr.entrySet().stream() + .forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue())); + List replacedGroupBy = agg.getGroupByExpressions().stream() .map(groupBy -> { if (groupBy instanceof UnboundSlot) { UnboundSlot unboundSlot = (UnboundSlot) groupBy; if (unboundSlot.getNameParts().size() == 1) { String name = unboundSlot.getNameParts().get(0); - if (aliasNameToExpr.containsKey(name)) { - return aliasNameToExpr.get(name); + if (childOutputsToExpr.containsKey(name)) { + return childOutputsToExpr.get(name); } } } @@ -197,10 +205,17 @@ public List buildRules() { List output = bind(repeat.getOutputExpressions(), repeat.children(), repeat, ctx.cascadesContext); + // The columns referenced in group by are first obtained from the child's output, + // and then from the node's output + Map childOutputsToExpr = repeat.child().getOutput().stream() + .collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr)); Map aliasNameToExpr = output.stream() .filter(ne -> ne instanceof Alias) .map(Alias.class::cast) .collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr)); + aliasNameToExpr.entrySet().stream() + .forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue())); + List> replacedGroupingSets = repeat.getGroupingSets().stream() .map(groupBy -> groupBy.stream().map(expr -> { @@ -208,8 +223,8 @@ public List buildRules() { UnboundSlot unboundSlot = (UnboundSlot) expr; if (unboundSlot.getNameParts().size() == 1) { String name = unboundSlot.getNameParts().get(0); - if (aliasNameToExpr.containsKey(name)) { - return aliasNameToExpr.get(name); + if (childOutputsToExpr.containsKey(name)) { + return childOutputsToExpr.get(name); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 2d7d3394ecf26e..7818fe2bb78208 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -119,14 +119,13 @@ private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { // normalize grouping sets to List> List> normalizedGroupingSets = repeat.getGroupingSets() .stream() - .map(groupingSet -> (List) (List) context.normalizeToUseSlotRef( - groupingSet, false)) + .map(groupingSet -> (List) (List) context.normalizeToUseSlotRef(groupingSet)) .collect(ImmutableList.toImmutableList()); // replace the arguments of grouping scalar function to virtual slots // replace some complex expression to slot, e.g. `a + 1` List normalizedAggOutput = context.normalizeToUseSlotRef( - repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction, true); + repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction); Set virtualSlotsInFunction = ExpressionUtils.collect(normalizedAggOutput, VirtualSlotReference.class::isInstance); @@ -153,7 +152,7 @@ private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { .addAll(allVirtualSlots) .build(); - Set pushedProject = context.pushDownToNamedExpression(needToSlots, repeat); + Set pushedProject = context.pushDownToNamedExpression(needToSlots); Plan normalizedChild = pushDownProject(pushedProject, repeat.child()); LogicalRepeat normalizedRepeat = repeat.withNormalizedExpr( @@ -242,7 +241,7 @@ public NormalizeToSlotContext buildContext(Repeat repeat, normalizeToSlotMap.put(expression, pushDownTriplet.get()); } } - return new NormalizeToSlotContext(normalizeToSlotMap, repeat); + return new NormalizeToSlotContext(normalizeToSlotMap); } private Optional toGroupingSetExpressionPushDownTriplet( @@ -256,8 +255,7 @@ private Optional toGroupingSetExpressionPushDownTriplet( private Expression normalizeGroupingScalarFunction(NormalizeToSlotContext context, Expression expr) { if (expr instanceof GroupingScalarFunction) { GroupingScalarFunction function = (GroupingScalarFunction) expr; - List normalizedRealExpressions = context.normalizeToUseSlotRef( - function.getArguments(), false); + List normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); function = function.withChildren(normalizedRealExpressions); // eliminate GroupingScalarFunction and replace to VirtualSlotReference return Repeat.generateVirtualSlotByFunction(function); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index 85a5bdbd499269..b5a1ddc32b7594 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -65,9 +65,9 @@ public Rule build() { aggregate.getOutputExpressions(), Alias.class::isInstance); Set needToSlots = collectGroupByAndArgumentsOfAggregateFunctions(aggregate); NormalizeToSlotContext groupByAndArgumentToSlotContext = - NormalizeToSlotContext.buildContext(existsAliases, needToSlots, aggregate); + NormalizeToSlotContext.buildContext(existsAliases, needToSlots); Set bottomProjects = - groupByAndArgumentToSlotContext.pushDownToNamedExpression(needToSlots, aggregate); + groupByAndArgumentToSlotContext.pushDownToNamedExpression(needToSlots); Plan normalizedChild = bottomProjects.isEmpty() ? aggregate.child() : new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child()); @@ -78,7 +78,7 @@ public Rule build() { // some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace // the sum(value) to slot and move the `slot + 1` to the upper project later. List normalizeOutputPhase1 = groupByAndArgumentToSlotContext - .normalizeToUseSlotRef(aggregate.getOutputExpressions(), true); + .normalizeToUseSlotRef(aggregate.getOutputExpressions()); Set normalizedAggregateFunctions = ExpressionUtils.collect(normalizeOutputPhase1, AggregateFunction.class::isInstance); @@ -87,14 +87,14 @@ public Rule build() { // now reuse the exists alias for the aggregate functions, // or create new alias for the aggregate functions NormalizeToSlotContext aggregateFunctionToSlotContext = - NormalizeToSlotContext.buildContext(existsAliases, normalizedAggregateFunctions, aggregate); + NormalizeToSlotContext.buildContext(existsAliases, normalizedAggregateFunctions); Set normalizedAggregateFunctionsWithAlias = - aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions, aggregate); + aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions); List normalizedGroupBy = (List) groupByAndArgumentToSlotContext - .normalizeToUseSlotRef(aggregate.getGroupByExpressions(), false); + .normalizeToUseSlotRef(aggregate.getGroupByExpressions()); // we can safely add all groupBy and aggregate functions to output, because we will // add a project on it, and the upper project can protect the scope of visible of slot @@ -108,7 +108,7 @@ public Rule build() { // replace aggregate function to slot List upperProjects = - aggregateFunctionToSlotContext.normalizeToUseSlotRef(normalizeOutputPhase1, true); + aggregateFunctionToSlotContext.normalizeToUseSlotRef(normalizeOutputPhase1); return new LogicalProject<>(upperProjects, normalizedAggregate); }).toRule(RuleType.NORMALIZE_AGGREGATE); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java index d574c5953ca678..34686534e39548 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java @@ -21,8 +21,6 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction; -import org.apache.doris.nereids.trees.plans.Plan; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -31,7 +29,6 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import javax.annotation.Nullable; @@ -42,17 +39,14 @@ public interface NormalizeToSlot { /** NormalizeSlotContext */ class NormalizeToSlotContext { private final Map normalizeToSlotMap; - private final Plan currentPlan; - public NormalizeToSlotContext( - Map normalizeToSlotMap, Plan currentPlan) { + public NormalizeToSlotContext(Map normalizeToSlotMap) { this.normalizeToSlotMap = normalizeToSlotMap; - this.currentPlan = currentPlan; } /** buildContext */ public static NormalizeToSlotContext buildContext( - Set existsAliases, Set sourceExpressions, Plan currentPlan) { + Set existsAliases, Set sourceExpressions) { Map normalizeToSlotMap = Maps.newLinkedHashMap(); Map existsAliasMap = Maps.newLinkedHashMap(); @@ -68,67 +62,36 @@ public static NormalizeToSlotContext buildContext( NormalizeToSlotTriplet.toTriplet(expression, existsAliasMap.get(expression)); normalizeToSlotMap.put(expression, normalizeToSlotTriplet); } - return new NormalizeToSlotContext(normalizeToSlotMap, currentPlan); + return new NormalizeToSlotContext(normalizeToSlotMap); } /** normalizeToUseSlotRef, no custom normalize */ - public List normalizeToUseSlotRef( - List expressions, boolean isOutput) { - return normalizeToUseSlotRef(expressions, (context, expr) -> expr, isOutput); + public List normalizeToUseSlotRef(List expressions) { + return normalizeToUseSlotRef(expressions, (context, expr) -> expr); } /** normalizeToUseSlotRef */ public List normalizeToUseSlotRef(List expressions, - BiFunction customNormalize, - boolean isOutput) { + BiFunction customNormalize) { return expressions.stream() .map(expr -> (E) expr.rewriteDownShortCircuit(child -> { Expression newChild = customNormalize.apply(this, child); if (newChild != null && newChild != child) { return newChild; } - if (child instanceof ScalarFunction && !isOutput) { - return getSlotFromChildOutputsWhichEqualName(currentPlan, child); - } - if (child instanceof ScalarFunction && isOutput - && collectSlotFromChildOutputsWhichEqualName(currentPlan, child).isPresent()) { - NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); - return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.originExpr; - } NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; })).collect(ImmutableList.toImmutableList()); } - private Expression getSlotFromChildOutputsWhichEqualName(Plan repeat, Expression child) { - NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); - Optional slot = collectSlotFromChildOutputsWhichEqualName(repeat, child); - if (slot.isPresent()) { - return slot.get(); - } - return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; - } - - private Optional collectSlotFromChildOutputsWhichEqualName( - Plan repeat, Expression child) { - NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); - String asName = normalizeToSlotTriplet == null - ? child.toSql() : normalizeToSlotTriplet.remainExpr.getName(); - return repeat.child(0).getOutput().stream() - .filter(s -> s.getName().equals(asName)) - .findFirst(); - } - /** * generate bottom projections with groupByExpressions. * eg: * groupByExpressions: k1#0, k2#1 + 1; * bottom: k1#0, (k2#1 + 1) AS (k2 + 1)#2; */ - public Set pushDownToNamedExpression( - Collection needToPushExpressions, Plan current) { + public Set pushDownToNamedExpression(Collection needToPushExpressions) { return needToPushExpressions.stream() - .filter(e -> filterScalarFunWithSameAliasNameInChildOutput(e, current)) .map(expr -> { NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr); return normalizeToSlotTriplet == null @@ -136,15 +99,6 @@ public Set pushDownToNamedExpression( : normalizeToSlotTriplet.pushedExpr; }).collect(ImmutableSet.toImmutableSet()); } - - private boolean filterScalarFunWithSameAliasNameInChildOutput( - Expression expression, Plan current) { - if (expression instanceof ScalarFunction) { - Optional slot = collectSlotFromChildOutputsWhichEqualName(current, expression); - return !slot.isPresent(); - } - return true; - } } /** NormalizeToSlotTriplet */