diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index bb75713e249e2b..28307ec7b3aa1b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -85,6 +85,7 @@ import org.apache.doris.nereids.rules.rewrite.MergeAggregate; import org.apache.doris.nereids.rules.rewrite.MergeFilters; import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion; +import org.apache.doris.nereids.rules.rewrite.MergePercentileToArray; import org.apache.doris.nereids.rules.rewrite.MergeProjects; import org.apache.doris.nereids.rules.rewrite.MergeSetOperations; import org.apache.doris.nereids.rules.rewrite.MergeSetOperationsExcept; @@ -404,7 +405,8 @@ public class Rewriter extends AbstractBatchJobExecutor { ), topic("agg rewrite", // these rules should be put after mv optimization to avoid mv matching fail - topDown(new SumLiteralRewrite()) + topDown(new SumLiteralRewrite(), + new MergePercentileToArray()) ), // this rule batch must keep at the end of rewrite to do some plan check topic("Final rewrite and check", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 1090855190cac9..45439e4cd51bf5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -200,6 +200,7 @@ public enum RuleType { REWRITE_REPEAT_EXPRESSION(RuleTypeClass.REWRITE), EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE), REORDER_JOIN(RuleTypeClass.REWRITE), + MERGE_PERCENTILE_TO_ARRAY(RuleTypeClass.REWRITE), // Merge Consecutive plan MERGE_PROJECTS(RuleTypeClass.REWRITE), MERGE_FILTERS(RuleTypeClass.REWRITE), diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java new file mode 100644 index 00000000000000..f92ad84bde8525 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java @@ -0,0 +1,218 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.annotation.DependsRules; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Percentile; +import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; +import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.google.common.collect.Sets.SetView; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/**MergePercentileToArray + * LogicalAggregate 
(outputExpression:[percentile(a,0.1) as c1, percentile(a,0.22) as c2]) + * -> + * LogicalProject (projects: [element_at(percentile(a,[0.1,0.22])#1, 1) as c1, + * element_at(percentile(a,[0.1,0.22], 2)#1 as c2]) + * --+LogicalAggregate(outputExpression: percentile_array(a, [0.1, 0.22]) as percentile_array(a, [0.1, 0.22])#1) + * */ +@DependsRules({ + NormalizeAggregate.class +}) +public class MergePercentileToArray extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalAggregate(any()) + .then(this::doMerge) + .toRule(RuleType.MERGE_PERCENTILE_TO_ARRAY); + } + + // Merge percentile into percentile_array according to funcMap + private List getPercentileArrays(Map> funcMap) { + List newPercentileArrays = Lists.newArrayList(); + for (Map.Entry> entry : funcMap.entrySet()) { + List literals = new ArrayList<>(); + for (AggregateFunction aggFunc : entry.getValue()) { + List literal = aggFunc.child(1).collectToList(expr -> expr instanceof Literal); + literals.add((Literal) literal.get(0)); + } + ArrayLiteral arrayLiteral = new ArrayLiteral(literals); + PercentileArray percentileArray = null; + if (entry.getKey().isDistinct) { + percentileArray = new PercentileArray(true, entry.getKey().getExpression(), new Cast(arrayLiteral, + ArrayType.of(DoubleType.INSTANCE))); + } else { + percentileArray = new PercentileArray(entry.getKey().getExpression(), new Cast(arrayLiteral, + ArrayType.of(DoubleType.INSTANCE))); + } + newPercentileArrays.add(percentileArray); + } + return newPercentileArrays; + } + + // Find all the percentile functions and place them in the map + // with the first parameter of the percentile as the key + private Map> collectFuncMap(LogicalAggregate aggregate) { + Set aggregateFunctions = aggregate.getAggregateFunctions(); + Map> funcMap = new HashMap<>(); + for (AggregateFunction func : aggregateFunctions) { + if (!(func instanceof Percentile)) { + continue; + } + DistinctAndExpr distictAndExpr = new DistinctAndExpr(func.child(0), 
func.isDistinct()); + funcMap.computeIfAbsent(distictAndExpr, k -> new ArrayList<>()).add(func); + } + funcMap.entrySet().removeIf(entry -> entry.getValue().size() == 1); + return funcMap; + } + + private Plan doMerge(LogicalAggregate aggregate) { + Map> funcMap = collectFuncMap(aggregate); + if (funcMap.isEmpty()) { + return aggregate; + } + Set canMergePercentiles = Sets.newHashSet(); + for (Map.Entry> entry : funcMap.entrySet()) { + canMergePercentiles.addAll(entry.getValue()); + } + + Set aggregateFunctions = aggregate.getAggregateFunctions(); + SetView aggFuncsNotChange = Sets.difference(aggregateFunctions, canMergePercentiles); + + // construct new Aggregate + List newPercentileArrays = getPercentileArrays(funcMap); + ImmutableList.Builder normalizedAggOutputBuilder = + ImmutableList.builderWithExpectedSize(aggregate.getGroupByExpressions().size() + + aggFuncsNotChange.size() + newPercentileArrays.size()); + List groupBySlots = new ArrayList<>(); + for (Expression groupBy : aggregate.getGroupByExpressions()) { + groupBySlots.add(((NamedExpression) groupBy).toSlot()); + } + normalizedAggOutputBuilder.addAll(groupBySlots); + Set existsAliases = + ExpressionUtils.mutableCollect(aggregate.getOutputExpressions(), Alias.class::isInstance); + NormalizeToSlotContext notChangeFuncContext = NormalizeToSlotContext.buildContext(existsAliases, + aggFuncsNotChange); + NormalizeToSlotContext percentileArrayContext = NormalizeToSlotContext.buildContext(new HashSet<>(), + newPercentileArrays); + normalizedAggOutputBuilder.addAll(notChangeFuncContext.pushDownToNamedExpression(aggFuncsNotChange)); + normalizedAggOutputBuilder.addAll(percentileArrayContext.pushDownToNamedExpression(newPercentileArrays)); + LogicalAggregate newAggregate = aggregate.withAggOutput(normalizedAggOutputBuilder.build()); + + // construct new Project + List notChangeForProject = notChangeFuncContext.normalizeToUseSlotRef( + (Set) (Set) aggFuncsNotChange); + List newPercentileArrayForProject = 
percentileArrayContext.normalizeToUseSlotRef( + (List) (List) newPercentileArrays); + ImmutableList.Builder newProjectOutputExpressions = ImmutableList.builder(); + newProjectOutputExpressions.addAll((List) (List) notChangeForProject); + Map existsAliasMap = Maps.newHashMap(); + // existsAliasMap is used to keep upper plan refer the same expr + for (Alias alias : existsAliases) { + existsAliasMap.put(alias.child(), alias); + } + Map slotMap = Maps.newHashMap(); + // slotMap is used to find the correspondence + // between LogicalProject's element_at(percentile_array_slot_reference, i) which replaces the old percentile() + // and the merged percentile_array() in LogicalAggregate + for (int i = 0; i < newPercentileArrays.size(); i++) { + DistinctAndExpr distinctAndExpr = new DistinctAndExpr(newPercentileArrays.get(i) + .child(0), newPercentileArrays.get(i).isDistinct()); + slotMap.put(distinctAndExpr, (Slot) newPercentileArrayForProject.get(i)); + } + for (Map.Entry> entry : funcMap.entrySet()) { + for (int i = 0; i < entry.getValue().size(); i++) { + AggregateFunction aggFunc = entry.getValue().get(i); + Alias originAlias = existsAliasMap.get(aggFunc); + DistinctAndExpr distinctAndExpr = new DistinctAndExpr(aggFunc.child(0), aggFunc.isDistinct()); + Alias newAlias = new Alias(originAlias.getExprId(), new ElementAt(slotMap.get(distinctAndExpr), + new IntegerLiteral(i + 1)), originAlias.getName()); + newProjectOutputExpressions.add(newAlias); + } + } + newProjectOutputExpressions.addAll(groupBySlots); + return new LogicalProject(newProjectOutputExpressions.build(), newAggregate); + } + + private static class DistinctAndExpr { + private Expression expression; + private boolean isDistinct; + + public DistinctAndExpr(Expression expression, boolean isDistinct) { + this.expression = expression; + this.isDistinct = isDistinct; + } + + public Expression getExpression() { + return expression; + } + + public boolean isDistinct() { + return isDistinct; + } + + @Override + public 
boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DistinctAndExpr a = (DistinctAndExpr) o; + return isDistinct == a.isDistinct + && Objects.equals(expression, a.expression); + } + + @Override + public int hashCode() { + return Objects.hash(expression, isDistinct); + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java new file mode 100644 index 00000000000000..224db4fcac54be --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Test; + +public class MergePercentileToArrayTest extends TestWithFeService implements MemoPatternMatchSupported { + @Override + protected void runBeforeAll() throws Exception { + createDatabase("merge_percentile_to_array"); + createTable( + "create table merge_percentile_to_array.t (\n" + + "pk int, a int, b int\n" + + ")\n" + + "distributed by hash(pk) buckets 10\n" + + "properties('replication_num' = '1');" + ); + connectContext.setDatabase("default_cluster:merge_percentile_to_array"); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + } + + @Test + void eliminateMax() { + String sql = "select sum(a), percentile(pk, 0.1) as c1, percentile(pk, 0.2) as c2 from t group by b;"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalProject(logicalAggregate(any())).when(p -> + p.getProjects().get(1).toSql().contains("element_at(percentile_array") + && p.getProjects().get(2).toSql().contains("element_at(percentile_array")) + ); + } +} + diff --git a/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out b/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out new file mode 100644 index 00000000000000..b495302e80d3c8 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out @@ -0,0 +1,43 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !merge_two -- +52 1.0 2.0 + +-- !merge_three -- +52 1.0 2.0 2.0 + +-- !merge_two_group -- +52 1.0 2.0 1.0 3.25 + +-- !merge_two_group -- +52 1.0 1.0 1.0 3.25 + +-- !no_merge -- +52 1.0 + +-- !with_group_by -- +1 2.0 2.0 +1 2.0 2.0 +18 1.8 2.6 +2 1.1 1.2 +3 2.0 2.0 +3 2.0 2.0 +5 1.1 1.2 +5 3.0 3.0 +7 3.0 3.0 +7 6.0 6.0 + +-- !with_upper_refer -- +1.0 3.25 + +-- !with_expr -- +2.0 3.25 + +-- !no_other_agg_func -- +2.0 1.0 1 +2.2 3.0 3 +3.0 4.0 4 +4.0 2.0 2 +4.0 5.0 5 +7.0 \N \N +7.0 7.0 7 + diff --git a/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy b/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy new file mode 100644 index 00000000000000..2071d75ae85d4e --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
// Regression suite for merging several percentile() calls into one percentile_array().
// Every order_qt_<tag> statement below is compared against the block with the same tag in
// merge_percentile_to_array.out, so tags and statement order must stay in sync with that file.
suite("merge_percentile_to_array") {
    // Run on the Nereids planner only; do not fall back to the original planner.
    sql "SET enable_nereids_planner=true"
    sql "SET enable_fallback_to_original_planner=false"
    sql """
          DROP TABLE IF EXISTS test_merge_percentile
        """

    sql """
         create table test_merge_percentile(pk int, a int, b int) distributed by hash(pk) buckets 10
         properties('replication_num' = '1'); 
         """

    // The fixture contains NULLs in both a and b.
    sql """
         insert into test_merge_percentile values(2,1,3),(1,1,2),(3,5,6),(6,null,6),(4,5,6),(2,1,4),(2,3,5),(1,1,4)
         ,(3,5,6),(3,5,null),(6,7,1),(2,1,7),(2,4,2),(2,3,9),(1,3,6),(3,5,8),(3,2,8);
      """

    // Two percentile() calls on the same column, next to another aggregate function.
    order_qt_merge_two "select sum(a),percentile(pk, 0.1) as c1 , percentile(pk, 0.2) as c2 from test_merge_percentile;"
    // Three percentile() calls on the same column.
    // NOTE(review): the alias c2 is used twice in this statement - confirm that is intended.
    order_qt_merge_three """select sum(a),percentile(pk, 0.1) as c1 , percentile(pk, 0.2) as c2 ,
            percentile(pk, 0.4) as c2  from test_merge_percentile;"""
    // Two columns (pk and a), each with two percentile() calls; some outputs have no alias.
    order_qt_merge_two_group """select sum(a),percentile(pk, 0.1) , percentile(pk, 0.2),
            percentile(a, 0.1),percentile(a, 0.55) as c2 from test_merge_percentile;"""
    // The same percentile(pk, 0.1) appears under two aliases (c1 and c3), interleaved with column a.
    // NOTE(review): this reuses the tag merge_two_group from the statement above; the .out file
    // contains two blocks with that tag as well - confirm whether a distinct tag was intended.
    order_qt_merge_two_group """select sum(a),percentile(pk, 0.1) as c1 , percentile(a, 0.2) c2,
            percentile(pk, 0.1) c3, percentile(a, 0.55) as c4 from test_merge_percentile;"""
    // A single percentile() call.
    order_qt_no_merge "select sum(a),percentile(pk, 0.1) from test_merge_percentile;"
    // Same as merge_two, with a GROUP BY key.
    order_qt_with_group_by """select sum(a),percentile(pk, 0.1) as c1 , percentile(pk, 0.2) as c2 
            from test_merge_percentile group by b;"""

    // An outer query selects two of the percentile aliases from a subquery.
    order_qt_with_upper_refer """select c1, c2 from (
            select sum(a),percentile(pk, 0.1) as c1 , percentile(a, 0.2),percentile(pk, 0.1),
            percentile(a, 0.55) as c2 from test_merge_percentile) t;
            """
    // The first argument of percentile() is an expression (pk+1, abs(a)) instead of a plain column.
    order_qt_with_expr """
            select c1, c2 from (
            select sum(a),percentile(pk+1, 0.1) as c1 , percentile(abs(a), 0.2),percentile(pk+1, 0.3),
            percentile(abs(a), 0.55) as c2 from test_merge_percentile) t;
            """

    // percentile() is the only kind of aggregate function, grouped by a.
    order_qt_no_other_agg_func """select c1, c2, a from (
            select a, percentile(pk+1, 0.1) as c1 , percentile(abs(a), 0.2),percentile(pk+1, 0.3),
            percentile(abs(a), 0.55) as c2 from test_merge_percentile group by a) t;
            """

}
\ No newline at end of file