From 939acd869a7ac18681302c61907d5a4f53a8d0ec Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Wed, 2 Apr 2025 20:06:08 +0800 Subject: [PATCH 1/4] fix --- .../rules/rewrite/MergePercentileToArray.java | 6 ++- .../merge_percentile_to_array.out | 40 +++++++++++++++++++ .../merge_percentile_to_array.groovy | 8 ++++ 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java index fe81adf13bf29d..a45c9c9215c3c9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java @@ -79,8 +79,7 @@ private List getPercentileArrays(Map> entry : funcMap.entrySet()) { List literals = new ArrayList<>(); for (AggregateFunction aggFunc : entry.getValue()) { - List literal = aggFunc.child(1).collectToList(expr -> expr instanceof Literal); - literals.add((Literal) literal.get(0)); + literals.add((Literal) aggFunc.child(1)); } ArrayLiteral arrayLiteral = new ArrayLiteral(literals); PercentileArray percentileArray = null; @@ -105,6 +104,9 @@ private Map> collectFuncMap(LogicalAggr if (!(func instanceof Percentile)) { continue; } + if (!(func.child(1) instanceof Literal)) { + continue; + } DistinctAndExpr distictAndExpr = new DistinctAndExpr(func.child(0), func.isDistinct()); funcMap.computeIfAbsent(distictAndExpr, k -> new ArrayList<>()).add(func); } diff --git a/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out b/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out index 1b2f876cfba50a..02026ee3f2fb28 100644 --- a/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out +++ b/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out @@ -53,3 +53,43 @@ 7 6.0 6.0 6.0 9 1.2 1.2 1.8 +-- !grouping -- +\N 1.0 4.800000000000001 +\N 6.0 6.0 +\N 6.0 6.0 +1 1.0 1.0 +1 1.0 2.0 +1 1.1 1.9 +1 2.0 2.0 +1 2.0 2.0 +2 3.0 3.0 +2 3.0 3.0 +3 1.0 1.0 +3 1.2 2.0 +3 2.0 2.0 +3 2.0 2.0 +4 2.0 2.0 +4 2.0 2.0 +5 3.0 3.0 +5 3.0 3.0 +5 3.0 3.6 +5 3.0 3.8 +7 6.0 6.0 +7 6.0 6.0 + +-- !skip_fold -- +\N 6 6.0 6.0 6.0 +1 2 1.0 1.0 1.0 +1 3 2.0 2.0 2.0 +1 4 1.3 1.9 1.6 +1 7 2.0 2.0 2.0 +2 8 3.0 3.0 3.0 +3 5 2.0 2.0 2.0 +3 6 1.0 1.0 1.0 +3 9 2.0 2.0 2.0 +4 2 2.0 2.0 2.0 +5 \N 3.0 3.0 3.0 +5 6 3.0 3.8 3.2 +5 8 3.0 3.0 3.0 +7 1 6.0 6.0 6.0 + diff --git a/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy b/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy index 5bb13c6336c264..3326f62a6b605c 100644 --- a/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy @@ -61,4 +61,12 @@ suite("merge_percentile_to_array") { percentile(pk, 0.4) as c2 from test_merge_percentile;""" order_qt_same_percentile_group_by """select sum(a),percentile(pk, 0.1) as c1 , percentile(pk, 0.1) as c2 , percentile(pk, 0.4) as c2 from test_merge_percentile group by a;""" + + order_qt_grouping """ + select a,percentile(pk, 0.1),percentile(pk, 0.9) from test_merge_percentile group by grouping sets((a,b),(a),()) + """ + sql "set debug_skip_fold_constant=true;" + order_qt_skip_fold """ + select a,b,percentile(pk, 0.1+0.2),percentile(pk, 0.9),percentile(pk, 0.6) from test_merge_percentile group by a,b + """ } \ No newline at end of file From 9ea97f92e0b66a822709980a1acb5eeff14ef4c9 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Fri, 18 Apr 2025 17:11:17 +0800 Subject: [PATCH 2/4] add grouping support --- .../rules/rewrite/MergePercentileToArray.java | 19 ++-- .../rewrite/MergePercentileToArrayTest.java | 13 +++ .../merge_percentile_to_array.out | 92 +++++++++++++++++++ .../merge_percentile_to_array.groovy | 13 +++ 4 files changed, 126 insertions(+), 11 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java index a45c9c9215c3c9..14f7defe8d28c2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java @@ -30,10 +30,9 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Percentile; import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Array; import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; -import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; -import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -76,18 +75,19 @@ public Rule build() { // Merge percentile into percentile_array according to funcMap private List getPercentileArrays(Map> funcMap) { List newPercentileArrays = Lists.newArrayList(); + for (Map.Entry> entry : funcMap.entrySet()) { - List literals = new ArrayList<>(); + List literals = new ArrayList<>(); for (AggregateFunction aggFunc : entry.getValue()) { - literals.add((Literal) aggFunc.child(1)); + literals.add(aggFunc.child(1)); } - ArrayLiteral arrayLiteral = new ArrayLiteral(literals); - PercentileArray percentileArray = null; + Array array = new Array(literals.toArray(new Expression[0])); + PercentileArray percentileArray; if (entry.getKey().isDistinct) { - percentileArray = new PercentileArray(true, entry.getKey().getExpression(), new Cast(arrayLiteral, + percentileArray = new PercentileArray(true, entry.getKey().getExpression(), new Cast(array, ArrayType.of(DoubleType.INSTANCE))); } else { - percentileArray = new PercentileArray(entry.getKey().getExpression(), new Cast(arrayLiteral, + percentileArray = new PercentileArray(entry.getKey().getExpression(), new Cast(array, ArrayType.of(DoubleType.INSTANCE))); } newPercentileArrays.add(percentileArray); @@ -104,9 +104,6 @@ private Map> collectFuncMap(LogicalAggr if (!(func instanceof Percentile)) { continue; } - if (!(func.child(1) instanceof Literal)) { - continue; - } DistinctAndExpr distictAndExpr = new DistinctAndExpr(func.child(0), func.isDistinct()); funcMap.computeIfAbsent(distictAndExpr, k -> new ArrayList<>()).add(func); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java index 224db4fcac54be..47597db1af3db9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java @@ -51,5 +51,18 @@ void eliminateMax() { && p.getProjects().get(2).toSql().contains("element_at(percentile_array")) ); } + + @Test + void testGrouping() { + String sql = "SELECT percentile(a, 0.11), percentile(a,0.25+0.1) as percentiles,sum(a) FROM t group by grouping sets((b),())"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalProject(logicalAggregate()).when(p -> + p.getProjects().get(0).toSql().contains("element_at(percentile_array") + && p.getProjects().get(1).toSql().contains("element_at(percentile_array")) + ); + } } diff --git a/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out b/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out index 02026ee3f2fb28..87c516ec59356b 100644 --- a/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out +++ b/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out @@ -93,3 +93,95 @@ 5 8 3.0 3.0 3.0 7 1 6.0 6.0 6.0 +-- !grouping_skip_fold -- +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.33 1.75 +2.33 2.75 +3.0 3.0 +3.0 3.0 +3.32 5.0 +3.66 4.5 +5.0 5.0 +5.0 5.0 +7.0 7.0 +7.0 7.0 + +-- !grouping_expr -- +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 2.25 +1.33 2.05 +2.33 3.05 +3.0 3.0 +3.0 3.0 +3.32 5.0 +3.66 5.0 +5.0 5.0 +5.0 5.0 +7.0 7.0 +7.0 7.0 + +-- !grouping_expr_other_agg -- +\N \N \N +1.0 1.0 1 +1.0 1.0 1 +1.0 1.0 1 +1.0 1.0 1 +1.0 2.25 1 +1.33 2.05 1 +2.0 2.0 2 +2.33 3.05 2 +3.0 3.0 3 +3.0 3.0 3 +3.0 3.0 3 +3.66 5.0 3 +4.0 4.0 4 +5.0 5.0 5 +5.0 5.0 5 +7.0 7.0 7 +7.0 7.0 7 + +-- !grouping_expr_other_agg_upper_ref -- +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.33 1 +2.33 2 +3.0 3 +3.0 3 +3.32 2 +3.66 3 +5.0 5 +5.0 5 +7.0 7 +7.0 7 + +-- !grouping_expr_other_agg_upper_ref_multi_transform -- +1.0 1.0 1.0 +1.0 1.0 2.25 +1.33 1.33 2.05 +2.33 2.33 3.05 +3.0 3.0 3.0 +4.33 3.575 5.0 +7.0 7.0 7.0 + +-- !grouping_multi_merge -- +1.0 1.0 1.0 1.0 +1.0 1.0 2.25 2.25 +1.33 1.33 2.05 2.05 +2.33 2.33 3.05 3.05 +3.0 3.0 3.0 3.0 +4.33 3.575 5.0 5.0 +7.0 7.0 7.0 7.0 + diff --git a/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy b/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy index 3326f62a6b605c..acbeaeb87915f8 100644 --- a/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy @@ -69,4 +69,17 @@ suite("merge_percentile_to_array") { order_qt_skip_fold """ select a,b,percentile(pk, 0.1+0.2),percentile(pk, 0.9),percentile(pk, 0.6) from test_merge_percentile group by a,b """ + order_qt_grouping_skip_fold "SELECT percentile(a, 0.11), percentile(a,0.25) as percentiles FROM test_merge_percentile ts group by grouping sets((b),(pk),())" + order_qt_grouping_expr "SELECT percentile(a, 0.11), percentile(a,0.25+0.1) as percentiles FROM test_merge_percentile ts group by grouping sets((b),(pk),())" + order_qt_grouping_expr_other_agg "SELECT percentile(a, 0.11), percentile(a,0.25+0.1) as percentiles, min(a) FROM test_merge_percentile ts group by grouping sets((b),(a),())" + order_qt_grouping_expr_other_agg_upper_ref """select c1,c3 from (SELECT percentile(a, 0.11) c1, percentile(a,0.25+0.1) as c2, min(a) c3 + FROM test_merge_percentile ts group by grouping sets((b),(pk),())) t""" + order_qt_grouping_expr_other_agg_upper_ref_multi_transform """select percentile(c1,0.5),percentile(c1,0.25), c2 from + (SELECT percentile(a, 0.11) c1, percentile(a,0.25+0.1) as c2, min(a) c3 + FROM test_merge_percentile ts group by grouping sets((b),(pk),())) t group by c2""" + order_qt_grouping_multi_merge """ + select percentile(c1,0.5), percentile(c1,0.25), percentile(c2,0.1),percentile(c2,0.1+0.6) from + (SELECT percentile(a, 0.11) c1, percentile(a,0.25+0.1) as c2, min(a) c3 + FROM test_merge_percentile ts group by grouping sets((b),(pk),())) t group by c2 + """ } \ No newline at end of file From 57d13e3f8f4a08eb7fe768d173eb41bd6039a572 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Thu, 1 May 2025 15:49:01 +0800 Subject: [PATCH 3/4] cast to ArrayLiteral if can --- .../rules/rewrite/MergePercentileToArray.java | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java index 14f7defe8d28c2..a2a1bb6c8a40d3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java @@ -32,7 +32,9 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray; import org.apache.doris.nereids.trees.expressions.functions.scalar.Array; import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; +import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -77,18 +79,31 @@ private List getPercentileArrays(Map newPercentileArrays = Lists.newArrayList(); for (Map.Entry> entry : funcMap.entrySet()) { - List literals = new ArrayList<>(); + List percentList = new ArrayList<>(); + boolean allPercentIsLiteral = true; for (AggregateFunction aggFunc : entry.getValue()) { - literals.add(aggFunc.child(1)); + Expression percent = aggFunc.child(1); + percentList.add(percent); + if (allPercentIsLiteral && !(percent instanceof Literal)) { + allPercentIsLiteral = false; + } + } + ArrayLiteral percentArrayLiteral = null; + Array percentArray = null; + if (allPercentIsLiteral) { + percentArrayLiteral = new ArrayLiteral((List) percentList); + } else { + percentArray = new Array(percentList.toArray(new Expression[0])); } - Array array = new Array(literals.toArray(new Expression[0])); + PercentileArray percentileArray; + Expression secondArg = allPercentIsLiteral + ? new Cast(percentArrayLiteral, ArrayType.of(DoubleType.INSTANCE)) + : new Cast(percentArray, ArrayType.of(DoubleType.INSTANCE)); if (entry.getKey().isDistinct) { - percentileArray = new PercentileArray(true, entry.getKey().getExpression(), new Cast(array, - ArrayType.of(DoubleType.INSTANCE))); + percentileArray = new PercentileArray(true, entry.getKey().getExpression(), secondArg); } else { - percentileArray = new PercentileArray(entry.getKey().getExpression(), new Cast(array, - ArrayType.of(DoubleType.INSTANCE))); + percentileArray = new PercentileArray(entry.getKey().getExpression(), secondArg); } newPercentileArrays.add(percentileArray); } From 9b6e95599a656ad6779978fc7ce44cfe3507163c Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Thu, 1 May 2025 15:59:33 +0800 Subject: [PATCH 4/4] use castIfNotSameType --- .../doris/nereids/rules/rewrite/MergePercentileToArray.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java index a2a1bb6c8a40d3..a791d7370807ce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java @@ -23,7 +23,6 @@ import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext; import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; @@ -41,6 +40,7 @@ import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -98,8 +98,8 @@ private List getPercentileArrays(Map