From 16b202c74576e9f620683e6dedc1a129cc07a0cc Mon Sep 17 00:00:00 2001 From: yujun Date: Wed, 13 Aug 2025 15:22:38 +0800 Subject: [PATCH 1/5] add project for unique function --- .../doris/nereids/jobs/executor/Rewriter.java | 7 + .../apache/doris/nereids/rules/RuleType.java | 1 + .../rewrite/AddProjectForUniqueFunction.java | 259 ++++++++++++++++++ .../rules/expression/SimplifyRangeTest.java | 12 +- .../add_project_for_unique_function.out | 52 ++++ .../add_project_for_unique_function.groovy | 61 +++++ .../unique_function/load.groovy | 13 + 7 files changed, 403 insertions(+), 2 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java create mode 100644 regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out create mode 100644 regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 447c58311a903d..99ad405bfa87cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -34,6 +34,7 @@ import org.apache.doris.nereids.rules.expression.QueryColumnCollector; import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit; import org.apache.doris.nereids.rules.rewrite.AddProjectForJoin; +import org.apache.doris.nereids.rules.rewrite.AddProjectForUniqueFunction; import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType; import org.apache.doris.nereids.rules.rewrite.AdjustNullable; import org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction; @@ -738,6 +739,12 @@ public class Rewriter extends AbstractBatchJobExecutor { topDown(new SumLiteralRewrite(), new MergePercentileToArray()) ), + topic("add projection for 
unique function", + // separate AddProjectForUniqueFunction and MergeProjectable + // to avoid an infinite loop if the code has a bug + topDown(new AddProjectForUniqueFunction()), + topDown(new MergeProjectable()) + ), topic("collect scan filter for hbo", // this rule is to collect filter on basic table for hbo usage topDown(new CollectPredicateOnScan()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 0f2c77b6a1cb95..2c5f81eb32b990 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -214,6 +214,7 @@ public enum RuleType { PUSH_DOWN_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE), ADD_PROJECT_FOR_JOIN(RuleTypeClass.REWRITE), + ADD_PROJECT_FOR_UNIQUE_FUNCTION(RuleTypeClass.REWRITE), VARIANT_SUB_PATH_PRUNING(RuleTypeClass.REWRITE), CLEAR_CONTEXT_STATUS(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java new file mode 100644 index 00000000000000..8940a51620da40 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.scalar.UniqueFunction; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate; +import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.JoinUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import 
com.google.common.collect.Maps; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + +/** extract unique function expressions that occur multiple times, and add them to a new child project. + * for example: + * before rewrite: filter(random() >= 5 and random() <= 10), where both random() calls share one unique expr id. + * after rewrite: filter(k >= 5 and k <= 10) -> project(random() as k) + */ +public class AddProjectForUniqueFunction implements RewriteRuleFactory { + + @Override + public List buildRules() { + return ImmutableList.of( + new GenerateRewrite().build(), + new OneRowRelationRewrite().build(), + new ProjectRewrite().build(), + new FilterRewrite().build(), + new HavingRewrite().build(), + new JoinRewrite().build() + ); + } + + private class GenerateRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalGenerate().thenApply(ctx -> { + LogicalGenerate generate = ctx.root; + Optional, LogicalProject>> + rewrittenOpt = rewriteExpressions(generate, generate.getGenerators()); + if (rewrittenOpt.isPresent()) { + return generate.withGenerators(rewrittenOpt.get().first) + .withChildren(rewrittenOpt.get().second); + } else { + return generate; + } + }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION); + } + } + + private class OneRowRelationRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalOneRowRelation().thenApply(ctx -> { + LogicalOneRowRelation oneRowRelation = ctx.root; + List uniqueFunctionAlias = tryGenUniqueFunctionAlias(oneRowRelation.getProjects()); + if (uniqueFunctionAlias.isEmpty()) { + return oneRowRelation; + } + + Map replaceMap = Maps.newHashMap(); + for (NamedExpression alias : uniqueFunctionAlias) { + replaceMap.put(alias.child(0), alias.toSlot()); + } + ImmutableList.Builder newProjectBuilder + = ImmutableList.builderWithExpectedSize(oneRowRelation.getProjects().size()); + for 
(NamedExpression expr : oneRowRelation.getProjects()) { + newProjectBuilder.add((NamedExpression) ExpressionUtils.replace(expr, replaceMap)); + } + return new LogicalProject<>( + newProjectBuilder.build(), + oneRowRelation.withProjects(uniqueFunctionAlias)); + }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION); + } + } + + private class ProjectRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalProject().thenApply(ctx -> { + LogicalProject project = ctx.root; + Optional, LogicalProject>> + rewrittenOpt = rewriteExpressions(project, project.getProjects()); + if (rewrittenOpt.isPresent()) { + return project.withProjectsAndChild(rewrittenOpt.get().first, rewrittenOpt.get().second); + } else { + return project; + } + }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION); + } + } + + private class FilterRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalFilter().thenApply(ctx -> { + LogicalFilter filter = ctx.root; + Optional, LogicalProject>> + rewrittenOpt = rewriteExpressions(filter, filter.getConjuncts()); + if (rewrittenOpt.isPresent()) { + return filter.withConjunctsAndChild( + ImmutableSet.copyOf(rewrittenOpt.get().first), + rewrittenOpt.get().second); + } else { + return filter; + } + }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION); + } + } + + private class HavingRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalHaving().thenApply(ctx -> { + LogicalHaving having = ctx.root; + Optional, LogicalProject>> + rewrittenOpt = rewriteExpressions(having, having.getConjuncts()); + if (rewrittenOpt.isPresent()) { + return having.withConjuncts(ImmutableSet.copyOf(rewrittenOpt.get().first)) + .withChildren(rewrittenOpt.get().second); + } else { + return having; + } + }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION); + } + } + + private class JoinRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return 
logicalJoin().thenApply(ctx -> { + LogicalJoin join = ctx.root; + int hashOtherConjunctsSize = join.getHashJoinConjuncts().size() + join.getOtherJoinConjuncts().size(); + int totalConjunctsSize = hashOtherConjunctsSize + join.getMarkJoinConjuncts().size(); + List allConjuncts = Lists.newArrayListWithExpectedSize(totalConjunctsSize); + allConjuncts.addAll(join.getHashJoinConjuncts()); + allConjuncts.addAll(join.getOtherJoinConjuncts()); + allConjuncts.addAll(join.getMarkJoinConjuncts()); + Optional, LogicalProject>> rewrittenOpt + = rewriteExpressions(join, allConjuncts); + if (!rewrittenOpt.isPresent()) { + return join; + } + + LogicalProject newLeftChild = rewrittenOpt.get().second; + List newAllConjuncts = rewrittenOpt.get().first; + List newHashOtherConjuncts = newAllConjuncts.subList(0, hashOtherConjunctsSize); + List newMarkJoinConjuncts = ImmutableList.copyOf( + newAllConjuncts.subList(hashOtherConjunctsSize, totalConjunctsSize)); + // TODO: code from FindHashConditionForJoin + Pair, List> pair = JoinUtils.extractExpressionForHashTable( + newLeftChild.getOutput(), join.right().getOutput(), newHashOtherConjuncts); + List newHashJoinConjuncts = pair.first; + List newOtherJoinConjuncts = pair.second; + JoinType joinType = join.getJoinType(); + if (joinType == JoinType.CROSS_JOIN && !newHashJoinConjuncts.isEmpty()) { + joinType = JoinType.INNER_JOIN; + } + return new LogicalJoin<>(joinType, + newHashJoinConjuncts, + newOtherJoinConjuncts, + newMarkJoinConjuncts, + join.getDistributeHint(), + join.getMarkJoinSlotReference(), + ImmutableList.of(newLeftChild, join.right()), + join.getJoinReorderContext()); + }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION); + } + } + + // extract unique function which exist multiple times from targets, + // then alias the unique function and put them into a child project, + // then rewrite targets with the alias names. 
+ private Optional, LogicalProject>> rewriteExpressions( + LogicalPlan plan, Collection targets) { + List uniqueFunctionAlias = tryGenUniqueFunctionAlias(targets); + if (uniqueFunctionAlias.isEmpty()) { + return Optional.empty(); + } + + List projects = ImmutableList.builder() + .addAll(plan.child(0).getOutputSet()) + .addAll(uniqueFunctionAlias) + .build(); + + Map replaceMap = Maps.newHashMap(); + for (NamedExpression alias : uniqueFunctionAlias) { + replaceMap.put(alias.child(0), alias.toSlot()); + } + ImmutableList.Builder newTargetsBuilder = ImmutableList.builderWithExpectedSize(targets.size()); + for (T target : targets) { + newTargetsBuilder.add((T) ExpressionUtils.replace(target, replaceMap)); + } + + return Optional.of(Pair.of(newTargetsBuilder.build(), new LogicalProject<>(projects, plan.child(0)))); + } + + // if a unique function exists multiple times in the targets, then add a project to alias it. + private List tryGenUniqueFunctionAlias(Collection targets) { + Map unqiueFunctionCounter = Maps.newLinkedHashMap(); + targets.forEach(target -> target.foreach(e -> { + Expression expr = (Expression) e; + if (expr instanceof UniqueFunction) { + unqiueFunctionCounter.merge((UniqueFunction) expr, 1, Integer::sum); + } + })); + + ImmutableList.Builder builder + = ImmutableList.builderWithExpectedSize(unqiueFunctionCounter.size()); + for (Entry entry : unqiueFunctionCounter.entrySet()) { + if (entry.getValue() > 1) { + ExprId exprId = StatementScopeIdGenerator.newExprId(); + String name = "$_" + entry.getKey().getName() + "_" + exprId.asInt() + "_$"; + builder.add(new Alias(exprId, entry.getKey(), name)); + } + } + + return builder.build(); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java index 36c96179f0fdae..e6e856852d0c13 100644 --- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java @@ -240,8 +240,8 @@ public void testSimplify() { assertRewrite("SA > '20250101' and SA > '20260110'", "SA > '20260110'"); // random is non-foldable, so the two random(1, 10) are distinct, cann't merge range for them. - Expression expr = rewrite("TA + random(1, 10) > 10 AND TA + random(1, 10) < 1", Maps.newHashMap()); - Assertions.assertEquals("AND[((cast(TA as BIGINT) + random(1, 10)) > 10),((cast(TA as BIGINT) + random(1, 10)) < 1)]", expr.toSql()); + Expression expr = rewriteExpression("X + random(1, 10) > 10 AND X + random(1, 10) < 1", true); + Assertions.assertEquals("AND[((X + random(1, 10)) > 10),((X + random(1, 10)) < 1)]", expr.toSql()); expr = rewrite("TA + random(1, 10) between 10 and 20", Maps.newHashMap()); Assertions.assertEquals("AND[((cast(TA as BIGINT) + random(1, 10)) >= 10),((cast(TA as BIGINT) + random(1, 10)) <= 20)]", expr.toSql()); @@ -446,6 +446,14 @@ private void assertRewriteNotNull(String expression, String expected) { Assertions.assertEquals(expectedExpression, rewrittenExpression); } + private Expression rewriteExpression(String expression, boolean nullable) { + Map mem = Maps.newHashMap(); + Expression needRewriteExpression = PARSER.parseExpression(expression); + needRewriteExpression = nullable ? 
replaceUnboundSlot(needRewriteExpression, mem) : replaceNotNullUnboundSlot(needRewriteExpression, mem); + needRewriteExpression = typeCoercion(needRewriteExpression); + return executor.rewrite(needRewriteExpression, context); + } + private Expression replaceUnboundSlot(Expression expression, Map mem) { List children = Lists.newArrayList(); boolean hasNewChildren = false; diff --git a/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out b/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out new file mode 100644 index 00000000000000..2abcf9841a6ce8 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out @@ -0,0 +1,52 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !one_row_relation_1 -- +PhysicalResultSink +--PhysicalOneRowRelation[random(1, 100) AS `random(1, 100)`, uuid_to_int(uuid()) AS `uuid_to_int(uuid())`] + +-- !one_row_relation_2 -- +PhysicalResultSink +--PhysicalProject[AND[($_random_5_$ >= 10),($_random_5_$ <= 20)] AS `random(1, 100) between 10 and 20`, uuid_to_int(uuid()) AS `uuid_to_int(uuid())`] +----PhysicalOneRowRelation[random(1, 100) AS `$_random_5_$`] + +-- !one_row_relation_3 -- +PhysicalResultSink +--PhysicalProject[AND[($_random_5_$ >= 10),($_random_5_$ <= 20)] AS `random(1, 100) between 10 and 20`, AND[(uuid_to_int($_uuid_6_$) >= 111),(uuid_to_int($_uuid_6_$) <= 222)] AS `uuid_to_int(uuid()) between 111 and 222`] +----PhysicalOneRowRelation[random(1, 100) AS `$_random_5_$`, uuid() AS `$_uuid_6_$`] + +-- !project_1 -- +PhysicalResultSink +--PhysicalProject[((cast(id as BIGINT) + random(1, 100)) > 20) AS `id + random(1, 100) > 20`, (cast(id as BIGINT) * 200) AS `id * 200`] +----PhysicalOlapScan[t1] + +-- !project_2 -- +PhysicalResultSink +--PhysicalProject[(cast(id as BIGINT) * 200) AS `id * 200`, AND[((cast(id as BIGINT) + $_random_5_$) >= 10),((cast(id as BIGINT) + 
$_random_5_$) <= 20)] AS `id + random(1, 100) between 10 and 20`] +----PhysicalProject[random(1, 100) AS `$_random_5_$`, t1.id] +------PhysicalOlapScan[t1] + +-- !select_1 -- +PhysicalResultSink +--PhysicalProject[t1.id] +----filter(((cast(id as BIGINT) + random(1, 100)) >= 10)) +------PhysicalOlapScan[t1] + +-- !select_2 -- +PhysicalResultSink +--PhysicalProject[t1.id] +----filter(((cast(id as BIGINT) + $_random_3_$) <= 20) and ((cast(id as BIGINT) + $_random_3_$) >= 10)) +------PhysicalProject[random(1, 100) AS `$_random_3_$`, t1.id] +--------PhysicalOlapScan[t1] + +-- !join_1 -- +PhysicalResultSink +--PhysicalProject[t1.id, t1.msg, t2.id, t2.msg] +----NestedLoopJoin[INNER_JOIN](((cast(id as BIGINT) + cast(id as BIGINT)) + $_random_9_$) >= 10)(((cast(id as BIGINT) + cast(id as BIGINT)) + $_random_9_$) <= 20) +------PhysicalProject[cast(id as BIGINT) AS `cast(id as BIGINT)`, random(1, 100) AS `$_random_9_$`, t1.id, t1.msg] +--------filter(($_random_10_$ <= 10) and ($_random_10_$ >= 1)) +----------PhysicalProject[random(1, 100) AS `$_random_10_$`, t1.id, t1.msg] +------------PhysicalOlapScan[t1] +------PhysicalProject[cast(id as BIGINT) AS `cast(id as BIGINT)`, t2.id, t2.msg] +--------filter(((cast(id as BIGINT) * $_random_11_$) <= 200) and ((cast(id as BIGINT) * $_random_11_$) >= 100)) +----------PhysicalProject[random(1, 100) AS `$_random_11_$`, t2.id, t2.msg] +------------PhysicalOlapScan[t2] + diff --git a/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy b/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy new file mode 100644 index 00000000000000..d27f9c57ef6cdd --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite('add_project_for_unique_function') { + sql 'SET enable_nereids_planner=true' + sql 'SET runtime_filter_mode=OFF' + sql 'SET enable_fallback_to_original_planner=false' + sql "SET ignore_shape_nodes='PhysicalDistribute'" + sql "SET detail_shape_nodes='PhysicalProject,PhysicalOneRowRelation'" + sql 'SET disable_nereids_rules=PRUNE_EMPTY_PARTITION' + + // no project + qt_one_row_relation_1 ''' + explain shape plan select random(1, 100), uuid_to_int(uuid()) + ''' + + qt_one_row_relation_2 ''' + explain shape plan select random(1, 100) between 10 and 20, uuid_to_int(uuid()) + ''' + + qt_one_row_relation_3 ''' + explain shape plan select random(1, 100) between 10 and 20, uuid_to_int(uuid()) between 111 and 222 + ''' + + qt_project_1 ''' + explain shape plan select id + random(1, 100) > 20, id * 200 from t1 + ''' + + qt_project_2 ''' + explain shape plan select id + random(1, 100) between 10 and 20, id * 200 from t1 + ''' + + qt_select_1 ''' + explain shape plan select id from t1 where id + random(1, 100) >= 10 + ''' + + qt_select_2 ''' + explain shape plan select id from t1 where id + random(1, 100) between 10 and 20 + ''' + + qt_join_1 ''' + explain shape plan select * from t1 join t2 on + t1.id + t2.id + random(1, 100) between 10 and 20 + and t2.id * random(1, 100) 
between 100 and 200 + and random(1, 100) between 1 and 10 + ''' +} diff --git a/regression-test/suites/nereids_rules_p0/unique_function/load.groovy b/regression-test/suites/nereids_rules_p0/unique_function/load.groovy index cf4746b4976e5d..65bb565ff6b8da 100644 --- a/regression-test/suites/nereids_rules_p0/unique_function/load.groovy +++ b/regression-test/suites/nereids_rules_p0/unique_function/load.groovy @@ -33,6 +33,19 @@ suite("load") { "replication_allocation" = "tag.location.default: 1" ); """ + sql """ + DROP TABLE IF EXISTS t2 + """ + sql """ + CREATE TABLE IF NOT EXISTS t2( + `id` int(11) NULL, + `msg` text NULL + ) ENGINE = OLAP + DISTRIBUTED BY HASH(id) BUCKETS 4 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ def tbl = "tbl_unique_function_with_one_row" sql "drop table if exists ${tbl} force" From a0a12e45893802a7aefe45da7abd2dd0013e9fc2 Mon Sep 17 00:00:00 2001 From: yujun Date: Mon, 1 Sep 2025 14:58:56 +0800 Subject: [PATCH 2/5] fix between expression --- .../nereids/rules/analysis/BindExpression.java | 9 ++++++--- .../nereids/trees/expressions/Between.java | 17 +++++++++++++++-- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 915501c8dbe3f4..cc144702fbd540 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -607,9 +607,12 @@ private LogicalJoin bindJoin(MatchingContext Builder otherJoinConjuncts = ImmutableList.builderWithExpectedSize( join.getOtherJoinConjuncts().size()); for (Expression otherJoinConjunct : join.getOtherJoinConjuncts()) { - otherJoinConjunct = analyzer.analyze(otherJoinConjunct); - otherJoinConjunct = TypeCoercionUtils.castIfNotSameType(otherJoinConjunct, 
BooleanType.INSTANCE); - otherJoinConjuncts.add(otherJoinConjunct); + // after analysis, 'a between 1 and 10' is rewritten to 'a >= 1 and a <= 10' + Expression boundExpr = analyzer.analyze(otherJoinConjunct); + for (Expression conjunct : ExpressionUtils.extractConjunction(boundExpr)) { + conjunct = TypeCoercionUtils.castIfNotSameType(conjunct, BooleanType.INSTANCE); + otherJoinConjuncts.add(conjunct); + } } return new LogicalJoin<>(join.getJoinType(), hashJoinConjuncts.build(), otherJoinConjuncts.build(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Between.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Between.java index 14cc02dadff046..2c28c47a85ab33 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Between.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Between.java @@ -18,7 +18,6 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BooleanType; @@ -32,7 +31,7 @@ /** * Between predicate expression. */ -public class Between extends Expression implements TernaryExpression, PropagateNullable { +public class Between extends Expression implements TernaryExpression { private final Expression compareExpr; private final Expression lowerBound; @@ -77,6 +76,20 @@ public String toString() { return compareExpr + " BETWEEN " + lowerBound + " AND " + upperBound; } + // nullable is true if any child is nullable, + // but between is not PropagateNullable, + // because FoldConstantRuleOnFE will fold a PropagateNullable expression to NULL if any child is NULL. + // but `4 BETWEEN NULL AND 3` should fold to FALSE, not NULL. 
+ @Override + public boolean nullable() { + for (Expression child : children()) { + if (child.nullable()) { + return true; + } + } + return false; + } + public R accept(ExpressionVisitor visitor, C context) { return visitor.visitBetween(this, context); } From 37a31036d36772e96873b6976f64370c1882b73e Mon Sep 17 00:00:00 2001 From: yujun Date: Tue, 2 Sep 2025 11:09:13 +0800 Subject: [PATCH 3/5] add fe ut --- .../rewrite/AddProjectForUniqueFunction.java | 21 ++- .../AddProjectForUniqueFunctionTest.java | 143 ++++++++++++++++++ 2 files changed, 156 insertions(+), 8 deletions(-) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java index 8940a51620da40..eb9fb2deb61116 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java @@ -40,6 +40,7 @@ import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; @@ -210,7 +211,8 @@ public Rule build() { // extract unique function which exist multiple times from targets, // then alias the unique function and put them into a child project, // then rewrite targets with the alias names. 
- private Optional, LogicalProject>> rewriteExpressions( + @VisibleForTesting + public Optional, LogicalProject>> rewriteExpressions( LogicalPlan plan, Collection targets) { List uniqueFunctionAlias = tryGenUniqueFunctionAlias(targets); if (uniqueFunctionAlias.isEmpty()) { @@ -235,14 +237,17 @@ private Optional, LogicalProject>> rew } // if a unique function exists multiple times in the targets, then add a project to alias it. - private List tryGenUniqueFunctionAlias(Collection targets) { + @VisibleForTesting + public List tryGenUniqueFunctionAlias(Collection targets) { Map unqiueFunctionCounter = Maps.newLinkedHashMap(); - targets.forEach(target -> target.foreach(e -> { - Expression expr = (Expression) e; - if (expr instanceof UniqueFunction) { - unqiueFunctionCounter.merge((UniqueFunction) expr, 1, Integer::sum); - } - })); + for (Expression target : targets) { + target.foreach(e -> { + Expression expr = (Expression) e; + if (expr instanceof UniqueFunction) { + unqiueFunctionCounter.merge((UniqueFunction) expr, 1, Integer::sum); + } + }); + } ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(unqiueFunctionCounter.size()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java new file mode 100644 index 00000000000000..1e45ce36f9b260 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.hint.DistributeHint; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Random; +import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.plans.DistributeType; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; + +public class AddProjectForUniqueFunctionTest 
implements MemoPatternMatchSupported { + private final LogicalOlapScan studentOlapScan + = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student); + + @Test + void testGenUniqueFunctionAlias() { + Random random1 = new Random(); + Random random2 = new Random(); + Random random3 = new Random(); + List expressions = ImmutableList.of( + new Add(random1, new Add(random1, new DoubleLiteral(1.0))) , + new Add(random2, random3), + random3); + List namedExpressions = new AddProjectForUniqueFunction().tryGenUniqueFunctionAlias(expressions); + Assertions.assertEquals(2, namedExpressions.size()); + Assertions.assertInstanceOf(Alias.class, namedExpressions.get(0)); + Assertions.assertEquals(((Alias) namedExpressions.get(0)).child(), random1); + Assertions.assertInstanceOf(Alias.class, namedExpressions.get(1)); + Assertions.assertEquals(((Alias) namedExpressions.get(1)).child(), random3); + } + + @Test + void testRewriteExpressionNoChange() { + Random random1 = new Random(); + Random random2 = new Random(); + Random random3 = new Random(); + List projections = ImmutableList.of( + new Alias(new Add(random1, new Add(new DoubleLiteral(1.0), new DoubleLiteral(1.0)))), + new Alias(new Add(random2, new DoubleLiteral(1.0))), + new Alias(random3)); + LogicalProject project = new LogicalProject(projections, studentOlapScan); + Optional, LogicalProject>> result = new AddProjectForUniqueFunction() + .rewriteExpressions(project, project.getProjects()); + Assertions.assertEquals(Optional.empty(), result); + } + + @Test + void testRewriteExpressionProjectSucc() { + Random random1 = new Random(); + Random random2 = new Random(); + List projections = ImmutableList.of( + new Alias(new Add(random1, new Add(new DoubleLiteral(1.0), new DoubleLiteral(1.0)))), + new Alias(new Add(random2, new DoubleLiteral(1.0))), + new Alias(random2)); + LogicalProject project = new LogicalProject(projections, studentOlapScan); + Optional, LogicalProject>> result = new 
AddProjectForUniqueFunction() + .rewriteExpressions(project, project.getProjects()); + Assertions.assertTrue(result.isPresent()); + Assertions.assertInstanceOf(LogicalProject.class, result.get().second); + LogicalProject bottomProject = (LogicalProject) result.get().second; + List bottomProjections = bottomProject.getProjects(); + Assertions.assertEquals(studentOlapScan.getOutput().size() + 1, bottomProjections.size()); + Assertions.assertEquals(studentOlapScan.getOutput(), bottomProjections.subList(0, studentOlapScan.getOutput().size())); + Alias alis = (Alias) bottomProjections.get(bottomProjections.size() - 1); + Assertions.assertEquals(alis.child(), random2); + List expectedTopProjections = ImmutableList.of( + projections.get(0), + new Alias(projections.get(1).getExprId(), new Add(alis.toSlot(), new DoubleLiteral(1.0))), + new Alias(projections.get(2).getExprId(), alis.toSlot()) + ); + Assertions.assertEquals(expectedTopProjections, result.get().first); + } + + @Test + void testRewriteJoin() { + LogicalOlapScan scoreOlapScan + = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.score); + SlotReference sid = (SlotReference) scoreOlapScan.getOutput().get(0); + Random random = new Random(); + LogicalJoin join = new LogicalJoin(JoinType.CROSS_JOIN, + ImmutableList.of(), + ImmutableList.of(new EqualTo(random, sid)), + ImmutableList.of(new EqualTo(random, new DoubleLiteral(1.0))), + new DistributeHint(DistributeType.NONE), + Optional.empty(), + ImmutableList.of(studentOlapScan, scoreOlapScan), + null); + + Plan root = PlanChecker.from(MemoTestUtils.createConnectContext(), join) + .applyTopDown(new AddProjectForUniqueFunction()) + .getPlan(); + Assertions.assertInstanceOf(LogicalJoin.class, root); + LogicalJoin newJoin = (LogicalJoin) root; + Assertions.assertInstanceOf(LogicalProject.class, newJoin.left()); + LogicalProject leftProject = (LogicalProject) newJoin.left(); + Assertions.assertEquals(studentOlapScan, leftProject.child()); + 
Assertions.assertEquals(scoreOlapScan, newJoin.right()); + Alias alias = (Alias) leftProject.getProjects().get(leftProject.getProjects().size() - 1); + Assertions.assertEquals(alias.child(), random); + Assertions.assertEquals(ImmutableList.of(new EqualTo(alias.toSlot(), sid)), newJoin.getHashJoinConjuncts()); + Assertions.assertEquals(ImmutableList.of(), newJoin.getOtherJoinConjuncts()); + Assertions.assertEquals(ImmutableList.of(new EqualTo(alias.toSlot(), new DoubleLiteral(1.0))), newJoin.getMarkJoinConjuncts()); + Assertions.assertEquals(JoinType.INNER_JOIN, newJoin.getJoinType()); + } +} From 8148cb92e8e3d7f3e5c10497efea9da027e9b045 Mon Sep 17 00:00:00 2001 From: yujun Date: Tue, 2 Sep 2025 11:30:49 +0800 Subject: [PATCH 4/5] fix check style --- .../rules/rewrite/AddProjectForUniqueFunction.java | 12 ++++++++---- .../rewrite/AddProjectForUniqueFunctionTest.java | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java index eb9fb2deb61116..ac645adff36e60 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java @@ -208,9 +208,11 @@ public Rule build() { } } - // extract unique function which exist multiple times from targets, - // then alias the unique function and put them into a child project, - // then rewrite targets with the alias names. + /** + * Extract the unique functions that appear multiple times in the targets, + * alias each of them in a new child project, + * then rewrite the targets to use those aliases.
+ */ @VisibleForTesting public Optional, LogicalProject>> rewriteExpressions( LogicalPlan plan, Collection targets) { @@ -236,7 +238,9 @@ public Optional, LogicalProject>> rewr return Optional.of(Pair.of(newTargetsBuilder.build(), new LogicalProject<>(projects, plan.child(0)))); } - // if a unique function exists multiple times in the targets, then add a project to alias it. + /** + * if a unique function exists multiple times in the targets, then add a project to alias it. + */ @VisibleForTesting public List tryGenUniqueFunctionAlias(Collection targets) { Map unqiueFunctionCounter = Maps.newLinkedHashMap(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java index 1e45ce36f9b260..715c4c3f1c5a69 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java @@ -56,7 +56,7 @@ void testGenUniqueFunctionAlias() { Random random2 = new Random(); Random random3 = new Random(); List expressions = ImmutableList.of( - new Add(random1, new Add(random1, new DoubleLiteral(1.0))) , + new Add(random1, new Add(random1, new DoubleLiteral(1.0))), new Add(random2, random3), random3); List namedExpressions = new AddProjectForUniqueFunction().tryGenUniqueFunctionAlias(expressions); From 4b43e8f6a08e55003fd74aaf834bb91af0cccb0c Mon Sep 17 00:00:00 2001 From: yujun Date: Tue, 2 Sep 2025 18:17:55 +0800 Subject: [PATCH 5/5] fix union and aggregate --- .../rewrite/AddProjectForUniqueFunction.java | 31 ++++ .../rewrite/MergeOneRowRelationIntoUnion.java | 11 +- .../doris/nereids/util/ExpressionUtils.java | 14 ++ .../add_project_for_unique_function.out | 145 +++++++++++++++++- .../add_project_for_unique_function.groovy | 82 +++++++++- 5 files changed, 277 insertions(+), 6 
deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java index ac645adff36e60..2c2537ac476d0a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.UniqueFunction; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate; import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; @@ -67,6 +68,7 @@ public List buildRules() { new ProjectRewrite().build(), new FilterRewrite().build(), new HavingRewrite().build(), + new AggregateRewrite().build(), new JoinRewrite().build() ); } @@ -165,6 +167,35 @@ public Rule build() { } } + private class AggregateRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalAggregate().thenApply(ctx -> { + LogicalAggregate aggregate = ctx.root; + List targets = Lists.newArrayList(); + targets.addAll(aggregate.getGroupByExpressions()); + targets.addAll(aggregate.getOutputExpressions()); + Optional, LogicalProject>> rewrittenOpt + = rewriteExpressions(aggregate, targets); + if (!rewrittenOpt.isPresent()) { + return aggregate; + } + + LogicalProject newChild = rewrittenOpt.get().second; + List newTargets = rewrittenOpt.get().first; + int groupBySize = aggregate.getGroupByExpressions().size(); + ImmutableList newGroupBy = ImmutableList.copyOf( + newTargets.subList(0, groupBySize)); + ImmutableList.Builder newOutputBuilder + = 
ImmutableList.builderWithExpectedSize(aggregate.getOutputExpressions().size()); + for (int i = groupBySize; i < newTargets.size(); i++) { + newOutputBuilder.add((NamedExpression) newTargets.get(i)); + } + return aggregate.withChildGroupByAndOutput(newGroupBy, newOutputBuilder.build(), newChild); + }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION); + } + } + private class JoinRewrite extends OneRewriteRuleFactory { @Override public Rule build() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java index 06341a96038d59..abda37ff0b407b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.collect.ImmutableList; @@ -45,7 +46,11 @@ public Rule build() { ImmutableList.Builder> newChildrenOutputs = ImmutableList.builder(); for (int i = 0; i < u.arity(); i++) { Plan child = u.child(i); - if (!(child instanceof LogicalOneRowRelation)) { + // If a one-row relation contains a unique function that appears more than once, + // don't merge it; AddProjectForUniqueFunction will handle this one-row relation later.
+ if (!(child instanceof LogicalOneRowRelation) + || ExpressionUtils.containUniqueFunctionExistMultiple( + ((LogicalOneRowRelation) child).getProjects())) { newChildren.add(child); newChildrenOutputs.add(u.getRegularChildOutput(i)); } else { @@ -64,6 +69,10 @@ public Rule build() { constantExprsList.add(constantExprs.build()); } } + // no change + if (newChildren.size() == u.arity()) { + return u; + } return u.withChildrenAndConstExprsList(newChildren, newChildrenOutputs.build(), constantExprsList.build()); }).toRule(RuleType.MERGE_ONE_ROW_RELATION_INTO_UNION); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 4230240b82f2ca..cd04627ee30d71 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -1254,4 +1254,18 @@ public static boolean hasNonWindowAggregateFunction(Collection expressions) { + Set counterSet = Sets.newHashSet(); + for (Expression expression : expressions) { + if (expression.anyMatch( + expr -> expr instanceof UniqueFunction && !counterSet.add((UniqueFunction) expr))) { + return true; + } + } + return false; + } } diff --git a/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out b/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out index 2abcf9841a6ce8..456ee6b09bfed1 100644 --- a/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out +++ b/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out @@ -24,19 +24,160 @@ PhysicalResultSink ----PhysicalProject[random(1, 100) AS `$_random_5_$`, t1.id] ------PhysicalOlapScan[t1] --- !select_1 -- +-- !filter_1 -- PhysicalResultSink --PhysicalProject[t1.id] ----filter(((cast(id as BIGINT) + random(1, 100)) >= 10)) 
------PhysicalOlapScan[t1] --- !select_2 -- +-- !filter_2 -- PhysicalResultSink --PhysicalProject[t1.id] ----filter(((cast(id as BIGINT) + $_random_3_$) <= 20) and ((cast(id as BIGINT) + $_random_3_$) >= 10)) ------PhysicalProject[random(1, 100) AS `$_random_3_$`, t1.id] --------PhysicalOlapScan[t1] +-- !union_1 -- +PhysicalResultSink +--hashAgg[GLOBAL, groupByExpr=(k), outputExpr=(k)] +----hashAgg[LOCAL, groupByExpr=(k), outputExpr=(k)] +------PhysicalUnion(constantExprsList=[[TRUE AS `true`]]) +--------PhysicalProject[AND[($_random_8_$ >= 0.1),($_random_8_$ <= 0.5)] AS `k`] +----------PhysicalOneRowRelation[random() AS `$_random_8_$`] + +-- !union_2 -- +PhysicalResultSink +--hashAgg[GLOBAL, groupByExpr=(k), outputExpr=(k)] +----hashAgg[LOCAL, groupByExpr=(k), outputExpr=(k)] +------PhysicalUnion(constantExprsList=[[TRUE AS `true`]]) +--------PhysicalProject[AND[((cast(id as DOUBLE) + $_random_9_$) >= 0.1),((cast(id as DOUBLE) + $_random_9_$) <= 0.5)] AS `k`] +----------PhysicalProject[random() AS `$_random_9_$`, t1.id] +------------PhysicalOlapScan[t1] + +-- !union_all_1 -- +PhysicalResultSink +--PhysicalUnion(constantExprsList=[[TRUE AS `true`]]) +----PhysicalProject[AND[($_random_6_$ >= 0.1),($_random_6_$ <= 0.5)] AS `k`] +------PhysicalOneRowRelation[random() AS `$_random_6_$`] + +-- !union_all_2 -- +PhysicalResultSink +--PhysicalUnion(constantExprsList=[[TRUE AS `true`]]) +----PhysicalProject[AND[((cast(id as DOUBLE) + $_random_7_$) >= 0.1),((cast(id as DOUBLE) + $_random_7_$) <= 0.5)] AS `k`] +------PhysicalProject[random() AS `$_random_7_$`, t1.id] +--------PhysicalOlapScan[t1] + +-- !intersect_1 -- +PhysicalResultSink +--PhysicalIntersect +----PhysicalProject[AND[($_random_8_$ >= 0.1),($_random_8_$ <= 0.5)] AS `k`] +------filter((AND[($_random_8_$ >= 0.1),($_random_8_$ <= 0.5)] = TRUE)) +--------PhysicalOneRowRelation[random() AS `$_random_8_$`] +----PhysicalOneRowRelation[TRUE AS `true`] + +-- !intersect_2 -- +PhysicalResultSink +--PhysicalIntersect 
+----PhysicalProject[AND[((cast(id as DOUBLE) + $_random_9_$) >= 0.1),((cast(id as DOUBLE) + $_random_9_$) <= 0.5)] AS `k`] +------filter((AND[((cast(id as DOUBLE) + $_random_9_$) >= 0.1),((cast(id as DOUBLE) + $_random_9_$) <= 0.5)] = TRUE)) +--------PhysicalProject[random() AS `$_random_9_$`, t1.id] +----------PhysicalOlapScan[t1] +----PhysicalOneRowRelation[TRUE AS `true`] + +-- !except_1 -- +PhysicalResultSink +--PhysicalExcept +----PhysicalProject[AND[($_random_8_$ >= 0.1),($_random_8_$ <= 0.5)] AS `k`] +------PhysicalOneRowRelation[random() AS `$_random_8_$`] +----PhysicalOneRowRelation[TRUE AS `true`] + +-- !except_2 -- +PhysicalResultSink +--PhysicalExcept +----PhysicalProject[AND[((cast(id as DOUBLE) + $_random_9_$) >= 0.1),((cast(id as DOUBLE) + $_random_9_$) <= 0.5)] AS `k`] +------PhysicalProject[random() AS `$_random_9_$`, t1.id] +--------PhysicalOlapScan[t1] +----PhysicalOneRowRelation[TRUE AS `true`] + +-- !sort_1 -- +PhysicalResultSink +--PhysicalProject[t.k] +----PhysicalQuickSort[MERGE_SORT, orderKeys=(AND[((cast(k as DOUBLE) + random(100)) >= 0.6),((cast(k as DOUBLE) + random(100)) <= 0.7)] asc null first)] +------PhysicalQuickSort[LOCAL_SORT, orderKeys=(AND[((cast(k as DOUBLE) + random(100)) >= 0.6),((cast(k as DOUBLE) + random(100)) <= 0.7)] asc null first)] +--------PhysicalProject[AND[((cast(k as DOUBLE) + $_random_5_$) >= 0.6),((cast(k as DOUBLE) + $_random_5_$) <= 0.7)] AS `AND[((cast(k as DOUBLE) + random(100)) >= 0.6),((cast(k as DOUBLE) + random(100)) <= 0.7)]`, t.k] +----------PhysicalProject[AND[($_random_6_$ >= 0.1),($_random_6_$ <= 0.5)] AS `k`, random(100) AS `$_random_5_$`] +------------PhysicalOneRowRelation[random() AS `$_random_6_$`] + +-- !sort_2 -- +PhysicalResultSink +--PhysicalProject[t.k] +----PhysicalQuickSort[MERGE_SORT, orderKeys=(AND[((cast(k as DOUBLE) + random(100)) >= 0.6),((cast(k as DOUBLE) + random(100)) <= 0.7)] asc null first)] +------PhysicalQuickSort[LOCAL_SORT, orderKeys=(AND[((cast(k as DOUBLE) + 
random(100)) >= 0.6),((cast(k as DOUBLE) + random(100)) <= 0.7)] asc null first)] +--------PhysicalProject[AND[((cast(k as DOUBLE) + $_random_6_$) >= 0.6),((cast(k as DOUBLE) + $_random_6_$) <= 0.7)] AS `AND[((cast(k as DOUBLE) + random(100)) >= 0.6),((cast(k as DOUBLE) + random(100)) <= 0.7)]`, t.k] +----------PhysicalProject[AND[((cast(id as DOUBLE) + $_random_7_$) >= 0.1),((cast(id as DOUBLE) + $_random_7_$) <= 0.5)] AS `k`, random(100) AS `$_random_6_$`] +------------PhysicalProject[random() AS `$_random_7_$`, t1.id] +--------------PhysicalOlapScan[t1] + +-- !agg_1 -- +PhysicalResultSink +--hashAgg[GLOBAL, groupByExpr=(), outputExpr=(sum(cast(AND[($_random_5_$ >= 0.6),($_random_5_$ <= 0.7)] as TINYINT)) AS `sum(random(100) between 0.6 and 0.7)`)] +----hashAgg[LOCAL, groupByExpr=(), outputExpr=(partial_sum(cast(AND[($_random_5_$ >= 0.6),($_random_5_$ <= 0.7)] as TINYINT)) AS `partial_sum(cast(AND[($_random_5_$ >= 0.6),($_random_5_$ <= 0.7)] as TINYINT))`)] +------PhysicalProject[$_random_5_$, cast(AND[($_random_5_$ >= 0.6),($_random_5_$ <= 0.7)] as TINYINT) AS `cast(AND[($_random_5_$ >= 0.6),($_random_5_$ <= 0.7)] as TINYINT)`] +--------PhysicalOneRowRelation[random(100) AS `$_random_5_$`] + +-- !agg_2 -- +PhysicalResultSink +--hashAgg[GLOBAL, groupByExpr=(), outputExpr=(sum(cast(AND[($_random_5_$ >= 0.6),($_random_5_$ <= 0.7)] as TINYINT)) AS `sum(random(100) between 0.6 and 0.7)`, sum(id) AS `sum(id)`)] +----hashAgg[LOCAL, groupByExpr=(), outputExpr=(partial_sum(cast(AND[($_random_5_$ >= 0.6),($_random_5_$ <= 0.7)] as TINYINT)) AS `partial_sum(cast(AND[($_random_5_$ >= 0.6),($_random_5_$ <= 0.7)] as TINYINT))`, partial_sum(id) AS `partial_sum(id)`)] +------PhysicalProject[cast(AND[(random(100) >= 0.6),(random(100) <= 0.7)] as TINYINT) AS `cast(AND[($_random_5_$ >= 0.6),($_random_5_$ <= 0.7)] as TINYINT)`, random(100) AS `$_random_5_$`, t1.id] +--------PhysicalOlapScan[t1] + +-- !agg_3 -- +PhysicalResultSink +--PhysicalProject[sum(id), sum(random(100) between 
0.6 and 0.7)] +----hashAgg[GLOBAL, groupByExpr=(AND[(random() >= cast(0.1 as DOUBLE)),(random() <= cast(0.5 as DOUBLE))]), outputExpr=(AND[(random() >= cast(0.1 as DOUBLE)),(random() <= cast(0.5 as DOUBLE))], sum(cast(AND[($_random_7_$ >= 0.6),($_random_7_$ <= 0.7)] as TINYINT)) AS `sum(random(100) between 0.6 and 0.7)`, sum(id) AS `sum(id)`)] +------hashAgg[LOCAL, groupByExpr=(AND[(random() >= cast(0.1 as DOUBLE)),(random() <= cast(0.5 as DOUBLE))]), outputExpr=(AND[(random() >= cast(0.1 as DOUBLE)),(random() <= cast(0.5 as DOUBLE))], partial_sum(cast(AND[($_random_7_$ >= 0.6),($_random_7_$ <= 0.7)] as TINYINT)) AS `partial_sum(cast(AND[($_random_7_$ >= 0.6),($_random_7_$ <= 0.7)] as TINYINT))`, partial_sum(id) AS `partial_sum(id)`)] +--------PhysicalProject[AND[($_random_8_$ >= 0.1),($_random_8_$ <= 0.5)] AS `AND[(random() >= cast(0.1 as DOUBLE)),(random() <= cast(0.5 as DOUBLE))]`, cast(AND[(random(100) >= 0.6),(random(100) <= 0.7)] as TINYINT) AS `cast(AND[($_random_7_$ >= 0.6),($_random_7_$ <= 0.7)] as TINYINT)`, random(100) AS `$_random_7_$`, t1.id] +----------PhysicalProject[random() AS `$_random_8_$`, t1.id] +------------PhysicalOlapScan[t1] + +-- !agg_4 -- +PhysicalResultSink +--PhysicalProject[sum(id), sum(random(100) between 0.6 and 0.7)] +----hashAgg[GLOBAL, groupByExpr=(AND[((cast(id as DOUBLE) + random()) >= cast(0.1 as DOUBLE)),((cast(id as DOUBLE) + random()) <= cast(0.5 as DOUBLE))]), outputExpr=(AND[((cast(id as DOUBLE) + random()) >= cast(0.1 as DOUBLE)),((cast(id as DOUBLE) + random()) <= cast(0.5 as DOUBLE))], sum(cast(AND[($_random_7_$ >= 0.6),($_random_7_$ <= 0.7)] as TINYINT)) AS `sum(random(100) between 0.6 and 0.7)`, sum(id) AS `sum(id)`)] +------hashAgg[LOCAL, groupByExpr=(AND[((cast(id as DOUBLE) + random()) >= cast(0.1 as DOUBLE)),((cast(id as DOUBLE) + random()) <= cast(0.5 as DOUBLE))]), outputExpr=(AND[((cast(id as DOUBLE) + random()) >= cast(0.1 as DOUBLE)),((cast(id as DOUBLE) + random()) <= cast(0.5 as DOUBLE))], 
partial_sum(cast(AND[($_random_7_$ >= 0.6),($_random_7_$ <= 0.7)] as TINYINT)) AS `partial_sum(cast(AND[($_random_7_$ >= 0.6),($_random_7_$ <= 0.7)] as TINYINT))`, partial_sum(id) AS `partial_sum(id)`)] +--------PhysicalProject[AND[((cast(id as DOUBLE) + $_random_8_$) >= 0.1),((cast(id as DOUBLE) + $_random_8_$) <= 0.5)] AS `AND[((cast(id as DOUBLE) + random()) >= cast(0.1 as DOUBLE)),((cast(id as DOUBLE) + random()) <= cast(0.5 as DOUBLE))]`, cast(AND[(random(100) >= 0.6),(random(100) <= 0.7)] as TINYINT) AS `cast(AND[($_random_7_$ >= 0.6),($_random_7_$ <= 0.7)] as TINYINT)`, random(100) AS `$_random_7_$`, t1.id] +----------PhysicalProject[random() AS `$_random_8_$`, t1.id] +------------PhysicalOlapScan[t1] + +-- !window_1 -- +PhysicalResultSink +--PhysicalProject[sum(random(1) between 0.1 and 0.11) + over(partition by random(2) between 0.2 and 0.22)] +----PhysicalWindow +------PhysicalQuickSort[LOCAL_SORT, orderKeys=(AND[(random(2) >= 0.2),(random(2) <= 0.22)] asc)] +--------PhysicalProject[AND[($_random_8_$ >= 0.2),($_random_8_$ <= 0.22)] AS `AND[(random(2) >= 0.2),(random(2) <= 0.22)]`, cast(AND[($_random_7_$ >= 0.1),($_random_7_$ <= 0.11)] as TINYINT) AS `cast(AND[(random(1) >= 0.1),(random(1) <= 0.11)] as TINYINT)`] +----------PhysicalOneRowRelation[random(1) AS `$_random_7_$`, random(2) AS `$_random_8_$`] + +-- !window_2 -- +PhysicalResultSink +--PhysicalProject[sum(random(1) between 0.1 and 0.11) + over(partition by random(2) between 0.2 and 0.22 order by random(3) between 0.3 and 0.33)] +----PhysicalWindow +------PhysicalQuickSort[LOCAL_SORT, orderKeys=(AND[(random(2) >= 0.2),(random(2) <= 0.22)] asc)] +--------PhysicalProject[AND[($_random_10_$ >= 0.2),($_random_10_$ <= 0.22)] AS `AND[(random(2) >= 0.2),(random(2) <= 0.22)]`, cast(AND[($_random_9_$ >= 0.1),($_random_9_$ <= 0.11)] as TINYINT) AS `cast(AND[(random(1) >= 0.1),(random(1) <= 0.11)] as TINYINT)`] +----------PhysicalOneRowRelation[random(1) AS `$_random_9_$`, random(2) AS `$_random_10_$`] + +-- 
!window_3 -- +PhysicalResultSink +--PhysicalProject[sum(id + random(1) between 0.1 and 0.11) + over(partition by id + random(2) between 0.2 and 0.22 order by id + random(3) between 0.3 and 0.33)] +----PhysicalWindow +------PhysicalQuickSort[LOCAL_SORT, orderKeys=(AND[((cast(id as DOUBLE) + random(2)) >= 0.2),((cast(id as DOUBLE) + random(2)) <= 0.22)] asc, AND[((cast(id as DOUBLE) + random(3)) >= 0.3),((cast(id as DOUBLE) + random(3)) <= 0.33)] asc null first)] +--------PhysicalProject[AND[((cast(id as DOUBLE) + $_random_10_$) >= 0.2),((cast(id as DOUBLE) + $_random_10_$) <= 0.22)] AS `AND[((cast(id as DOUBLE) + random(2)) >= 0.2),((cast(id as DOUBLE) + random(2)) <= 0.22)]`, AND[((cast(id as DOUBLE) + $_random_11_$) >= 0.3),((cast(id as DOUBLE) + $_random_11_$) <= 0.33)] AS `AND[((cast(id as DOUBLE) + random(3)) >= 0.3),((cast(id as DOUBLE) + random(3)) <= 0.33)]`, cast(AND[((cast(id as DOUBLE) + $_random_9_$) >= 0.1),((cast(id as DOUBLE) + $_random_9_$) <= 0.11)] as TINYINT) AS `cast(AND[((cast(id as DOUBLE) + random(1)) >= 0.1),((cast(id as DOUBLE) + random(1)) <= 0.11)] as TINYINT)`] +----------PhysicalProject[random(1) AS `$_random_9_$`, random(2) AS `$_random_10_$`, random(3) AS `$_random_11_$`, t1.id] +------------PhysicalOlapScan[t1] + -- !join_1 -- PhysicalResultSink --PhysicalProject[t1.id, t1.msg, t2.id, t2.msg] diff --git a/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy b/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy index d27f9c57ef6cdd..9a400900ca3792 100644 --- a/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy +++ b/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy @@ -20,7 +20,7 @@ suite('add_project_for_unique_function') { sql 'SET runtime_filter_mode=OFF' sql 'SET enable_fallback_to_original_planner=false' sql "SET ignore_shape_nodes='PhysicalDistribute'" - sql 
"SET detail_shape_nodes='PhysicalProject,PhysicalOneRowRelation'" + sql "SET detail_shape_nodes='PhysicalProject,PhysicalOneRowRelation,PhysicalUnion,PhysicalQuickSort,PhysicalHashAggregate'" sql 'SET disable_nereids_rules=PRUNE_EMPTY_PARTITION' // no project @@ -44,14 +44,90 @@ suite('add_project_for_unique_function') { explain shape plan select id + random(1, 100) between 10 and 20, id * 200 from t1 ''' - qt_select_1 ''' + qt_filter_1 ''' explain shape plan select id from t1 where id + random(1, 100) >= 10 ''' - qt_select_2 ''' + qt_filter_2 ''' explain shape plan select id from t1 where id + random(1, 100) between 10 and 20 ''' + qt_union_1 ''' + explain shape plan select (random() between 0.1 and 0.5) as k union select true + ''' + + qt_union_2 ''' + explain shape plan select (id + random() between 0.1 and 0.5) as k from t1 union select true + ''' + + qt_union_all_1 ''' + explain shape plan select (random() between 0.1 and 0.5) as k union all select true + ''' + + qt_union_all_2 ''' + explain shape plan select (id + random() between 0.1 and 0.5) as k from t1 union all select true + ''' + + qt_intersect_1 ''' + explain shape plan select (random() between 0.1 and 0.5) as k intersect select true + ''' + + qt_intersect_2 ''' + explain shape plan select (id + random() between 0.1 and 0.5) as k from t1 intersect select true + ''' + + qt_except_1 ''' + explain shape plan select (random() between 0.1 and 0.5) as k except select true + ''' + + qt_except_2 ''' + explain shape plan select (id + random() between 0.1 and 0.5) as k from t1 except select true + ''' + + qt_sort_1 ''' + explain shape plan select * from (select (random() between 0.1 and 0.5) as k) t + order by k + random(100) between 0.6 and 0.7 + ''' + + qt_sort_2 ''' + explain shape plan select * from (select (id + random() between 0.1 and 0.5) as k from t1) t + order by k + random(100) between 0.6 and 0.7 + ''' + + qt_agg_1 ''' + explain shape plan select sum(random(100) between 0.6 and 0.7) + ''' + + 
qt_agg_2 ''' + explain shape plan select sum(id), sum(random(100) between 0.6 and 0.7) from t1 + ''' + + qt_agg_3 ''' + explain shape plan select sum(id), sum(random(100) between 0.6 and 0.7) from t1 + group by random() between 0.1 and 0.5 + ''' + + qt_agg_4 ''' + explain shape plan select sum(id), sum(random(100) between 0.6 and 0.7) from t1 + group by id + random() between 0.1 and 0.5 + ''' + + qt_window_1 ''' + explain shape plan select sum(random(1) between 0.1 and 0.11) + over(partition by random(2) between 0.2 and 0.22) + ''' + + qt_window_2 ''' + explain shape plan select sum(random(1) between 0.1 and 0.11) + over(partition by random(2) between 0.2 and 0.22 order by random(3) between 0.3 and 0.33) + ''' + + qt_window_3 ''' + explain shape plan select sum(id + random(1) between 0.1 and 0.11) + over(partition by id + random(2) between 0.2 and 0.22 order by id + random(3) between 0.3 and 0.33) + from t1 + ''' + qt_join_1 ''' explain shape plan select * from t1 join t2 on t1.id + t2.id + random(1, 100) between 10 and 20