diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java index fe81adf13bf29d..a791d7370807ce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArray.java @@ -23,13 +23,13 @@ import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext; import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Percentile; import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Array; import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; @@ -40,6 +40,7 @@ import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -76,20 +77,33 @@ public Rule build() { // Merge percentile into percentile_array according to funcMap private List getPercentileArrays(Map> funcMap) { List newPercentileArrays = Lists.newArrayList(); + for (Map.Entry> entry : funcMap.entrySet()) { - List literals = new ArrayList<>(); + List percentList = new ArrayList<>(); + boolean allPercentIsLiteral = true; for (AggregateFunction aggFunc : entry.getValue()) { - List literal = aggFunc.child(1).collectToList(expr -> expr instanceof Literal); - literals.add((Literal) literal.get(0)); + Expression percent = aggFunc.child(1); + percentList.add(percent); + if (allPercentIsLiteral && !(percent instanceof Literal)) { + allPercentIsLiteral = false; + } } - ArrayLiteral arrayLiteral = new ArrayLiteral(literals); - PercentileArray percentileArray = null; + ArrayLiteral percentArrayLiteral = null; + Array percentArray = null; + if (allPercentIsLiteral) { + percentArrayLiteral = new ArrayLiteral((List) percentList); + } else { + percentArray = new Array(percentList.toArray(new Expression[0])); + } + + PercentileArray percentileArray; + Expression secondArg = allPercentIsLiteral + ? TypeCoercionUtils.castIfNotSameType(percentArrayLiteral, ArrayType.of(DoubleType.INSTANCE)) + : TypeCoercionUtils.castIfNotSameType(percentArray, ArrayType.of(DoubleType.INSTANCE)); if (entry.getKey().isDistinct) { - percentileArray = new PercentileArray(true, entry.getKey().getExpression(), new Cast(arrayLiteral, - ArrayType.of(DoubleType.INSTANCE))); + percentileArray = new PercentileArray(true, entry.getKey().getExpression(), secondArg); } else { - percentileArray = new PercentileArray(entry.getKey().getExpression(), new Cast(arrayLiteral, - ArrayType.of(DoubleType.INSTANCE))); + percentileArray = new PercentileArray(entry.getKey().getExpression(), secondArg); } newPercentileArrays.add(percentileArray); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java index 224db4fcac54be..47597db1af3db9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergePercentileToArrayTest.java @@ -51,5 +51,18 @@ void eliminateMax() { && p.getProjects().get(2).toSql().contains("element_at(percentile_array")) ); } + + @Test + void testGrouping() { + String sql = "SELECT percentile(a, 0.11), percentile(a,0.25+0.1) as percentiles,sum(a) FROM t group by grouping sets((b),())"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalProject(logicalAggregate()).when(p -> + p.getProjects().get(0).toSql().contains("element_at(percentile_array") + && p.getProjects().get(1).toSql().contains("element_at(percentile_array")) + ); + } } diff --git a/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out b/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out index 1b2f876cfba50a..87c516ec59356b 100644 --- a/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out +++ b/regression-test/data/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.out @@ -53,3 +53,135 @@ 7 6.0 6.0 6.0 9 1.2 1.2 1.8 +-- !grouping -- +\N 1.0 4.800000000000001 +\N 6.0 6.0 +\N 6.0 6.0 +1 1.0 1.0 +1 1.0 2.0 +1 1.1 1.9 +1 2.0 2.0 +1 2.0 2.0 +2 3.0 3.0 +2 3.0 3.0 +3 1.0 1.0 +3 1.2 2.0 +3 2.0 2.0 +3 2.0 2.0 +4 2.0 2.0 +4 2.0 2.0 +5 3.0 3.0 +5 3.0 3.0 +5 3.0 3.6 +5 3.0 3.8 +7 6.0 6.0 +7 6.0 6.0 + +-- !skip_fold -- +\N 6 6.0 6.0 6.0 +1 2 1.0 1.0 1.0 +1 3 2.0 2.0 2.0 +1 4 1.3 1.9 1.6 +1 7 2.0 2.0 2.0 +2 8 3.0 3.0 3.0 +3 5 2.0 2.0 2.0 +3 6 1.0 1.0 1.0 +3 9 2.0 2.0 2.0 +4 2 2.0 2.0 2.0 +5 \N 3.0 3.0 3.0 +5 6 3.0 3.8 3.2 +5 8 3.0 3.0 3.0 +7 1 6.0 6.0 6.0 + +-- !grouping_skip_fold -- +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.33 1.75 +2.33 2.75 +3.0 3.0 +3.0 3.0 +3.32 5.0 +3.66 4.5 +5.0 5.0 +5.0 5.0 +7.0 7.0 +7.0 7.0 + +-- !grouping_expr -- +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 1.0 +1.0 2.25 +1.33 2.05 +2.33 3.05 +3.0 3.0 +3.0 3.0 +3.32 5.0 +3.66 5.0 +5.0 5.0 +5.0 5.0 +7.0 7.0 +7.0 7.0 + +-- !grouping_expr_other_agg -- +\N \N \N +1.0 1.0 1 +1.0 1.0 1 +1.0 1.0 1 +1.0 1.0 1 +1.0 2.25 1 +1.33 2.05 1 +2.0 2.0 2 +2.33 3.05 2 +3.0 3.0 3 +3.0 3.0 3 +3.0 3.0 3 +3.66 5.0 3 +4.0 4.0 4 +5.0 5.0 5 +5.0 5.0 5 +7.0 7.0 7 +7.0 7.0 7 + +-- !grouping_expr_other_agg_upper_ref -- +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.33 1 +2.33 2 +3.0 3 +3.0 3 +3.32 2 +3.66 3 +5.0 5 +5.0 5 +7.0 7 +7.0 7 + +-- !grouping_expr_other_agg_upper_ref_multi_transform -- +1.0 1.0 1.0 +1.0 1.0 2.25 +1.33 1.33 2.05 +2.33 2.33 3.05 +3.0 3.0 3.0 +4.33 3.575 5.0 +7.0 7.0 7.0 + +-- !grouping_multi_merge -- +1.0 1.0 1.0 1.0 +1.0 1.0 2.25 2.25 +1.33 1.33 2.05 2.05 +2.33 2.33 3.05 3.05 +3.0 3.0 3.0 3.0 +4.33 3.575 5.0 5.0 +7.0 7.0 7.0 7.0 + diff --git a/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy b/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy index 5bb13c6336c264..acbeaeb87915f8 100644 --- a/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_percentile_to_array/merge_percentile_to_array.groovy @@ -61,4 +61,25 @@ suite("merge_percentile_to_array") { percentile(pk, 0.4) as c2 from test_merge_percentile;""" order_qt_same_percentile_group_by """select sum(a),percentile(pk, 0.1) as c1 , percentile(pk, 0.1) as c2 , percentile(pk, 0.4) as c2 from test_merge_percentile group by a;""" + + order_qt_grouping """ + select a,percentile(pk, 0.1),percentile(pk, 0.9) from test_merge_percentile group by grouping sets((a,b),(a),()) + """ + sql "set debug_skip_fold_constant=true;" + order_qt_skip_fold """ + select a,b,percentile(pk, 0.1+0.2),percentile(pk, 0.9),percentile(pk, 0.6) from test_merge_percentile group by a,b + """ + order_qt_grouping_skip_fold "SELECT percentile(a, 0.11), percentile(a,0.25) as percentiles FROM test_merge_percentile ts group by grouping sets((b),(pk),())" + order_qt_grouping_expr "SELECT percentile(a, 0.11), percentile(a,0.25+0.1) as percentiles FROM test_merge_percentile ts group by grouping sets((b),(pk),())" + order_qt_grouping_expr_other_agg "SELECT percentile(a, 0.11), percentile(a,0.25+0.1) as percentiles, min(a) FROM test_merge_percentile ts group by grouping sets((b),(a),())" + order_qt_grouping_expr_other_agg_upper_ref """select c1,c3 from (SELECT percentile(a, 0.11) c1, percentile(a,0.25+0.1) as c2, min(a) c3 + FROM test_merge_percentile ts group by grouping sets((b),(pk),())) t""" + order_qt_grouping_expr_other_agg_upper_ref_multi_transform """select percentile(c1,0.5),percentile(c1,0.25), c2 from + (SELECT percentile(a, 0.11) c1, percentile(a,0.25+0.1) as c2, min(a) c3 + FROM test_merge_percentile ts group by grouping sets((b),(pk),())) t group by c2""" + order_qt_grouping_multi_merge """ + select percentile(c1,0.5), percentile(c1,0.25), percentile(c2,0.1),percentile(c2,0.1+0.6) from + (SELECT percentile(a, 0.11) c1, percentile(a,0.25+0.1) as c2, min(a) c3 + FROM test_merge_percentile ts group by grouping sets((b),(pk),())) t group by c2 + """ } \ No newline at end of file