From 63e742c9fef7b6fe62b17c828e39b79a567418e7 Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Mon, 21 Aug 2023 18:07:35 +0800 Subject: [PATCH 1/4] [feature](nereids)support subquery in select list --- .../doris/nereids/jobs/executor/Analyzer.java | 4 +- .../doris/nereids/jobs/executor/Rewriter.java | 7 +- .../NormalizeAggregate.java | 4 +- .../rules/analysis/SubqueryToApply.java | 86 +++++++------ .../implementation/AggregateStrategies.java | 2 +- .../rewrite/PushdownAliasThroughJoin.java | 5 +- .../nereids/trees/expressions/CaseWhen.java | 2 +- .../expressions/literal/DoubleLiteral.java | 5 + .../expressions/literal/FloatLiteral.java | 9 ++ .../analysis/AnalyzeWhereSubqueryTest.java | 70 +++++----- .../rewrite/AggregateStrategiesTest.java | 1 + ...tractAndNormalizeWindowExpressionTest.java | 1 + .../rules/rewrite/NormalizeAggregateTest.java | 1 + .../subquery/test_subquery_in_project.out | 50 ++++++++ .../subquery/test_subquery_in_project.groovy | 120 ++++++++++++++++++ 15 files changed, 291 insertions(+), 76 deletions(-) rename fe/fe-core/src/main/java/org/apache/doris/nereids/rules/{rewrite => analysis}/NormalizeAggregate.java (98%) create mode 100644 regression-test/data/nereids_p0/subquery/test_subquery_in_project.out create mode 100644 regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java index 1fb1d7eecdf080..c28abb120b48f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java @@ -31,6 +31,7 @@ import org.apache.doris.nereids.rules.analysis.CheckBound; import org.apache.doris.nereids.rules.analysis.CheckPolicy; import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.analysis.NormalizeRepeat; import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate; import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate; @@ -110,8 +111,9 @@ private static List buildAnalyzeJobs(Optional c // LogicalProject for normalize. This rule depends on FillUpMissingSlots to fill up slots. new NormalizeRepeat() ), - bottomUp(new SubqueryToApply()), bottomUp(new AdjustAggregateNullableForEmptySet()), + topDown(new NormalizeAggregate()), + bottomUp(new SubqueryToApply()), bottomUp(new CheckAnalysis()) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index ec2ea06eace9f0..f188a309a86146 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount; import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite; import org.apache.doris.nereids.rules.expression.ExpressionNormalization; import org.apache.doris.nereids.rules.expression.ExpressionOptimization; @@ -74,13 +75,13 @@ import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion; import org.apache.doris.nereids.rules.rewrite.MergeProjects; import org.apache.doris.nereids.rules.rewrite.MergeSetOperations; -import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate; import org.apache.doris.nereids.rules.rewrite.NormalizeSort; import org.apache.doris.nereids.rules.rewrite.PruneFileScanPartition; import org.apache.doris.nereids.rules.rewrite.PruneOlapScanPartition; import org.apache.doris.nereids.rules.rewrite.PruneOlapScanTablet; import org.apache.doris.nereids.rules.rewrite.PullUpCteAnchor; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan; +import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderApply; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan; import org.apache.doris.nereids.rules.rewrite.PushFilterInsideJoin; import org.apache.doris.nereids.rules.rewrite.PushProjectIntoOneRowRelation; @@ -139,6 +140,10 @@ public class Rewriter extends AbstractBatchJobExecutor { ), // subquery unnesting relay on ExpressionNormalization to extract common factor expression topic("Subquery unnesting", + // after moving NormalizeAggregate into analysis job + // we need run the following 3 rules before subquery unnesting + bottomUp(new PullUpProjectUnderApply(), new MergeProjects()), + topDown(new PushdownFilterThroughProject()), costBased( custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION, AggScalarSubQueryToWindowFunction::new) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java similarity index 98% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index eb683e8b5835db..6a141dce7a795c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index 6dfe95c1167b2b..388c4eebba8db1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -21,9 +21,7 @@ import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.BinaryOperator; -import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Exists; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InSubquery; @@ -44,6 +42,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import java.util.Collection; import java.util.LinkedHashMap; @@ -116,36 +115,45 @@ public List buildRules() { return new LogicalFilter<>(conjuncts, applyPlan); }) ), - RuleType.PROJECT_SUBQUERY_TO_APPLY.build( - logicalProject().thenApply(ctx -> { - LogicalProject project = ctx.root; - Set subqueryExprs = new LinkedHashSet<>(); - project.getProjects().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .filter(alias -> alias.child() instanceof CaseWhen) - .forEach(alias -> alias.child().children().stream() - .forEach(e -> - subqueryExprs.addAll(e.collect(SubqueryExpr.class::isInstance)))); - if (subqueryExprs.isEmpty()) { - return project; - } - - SubqueryContext context = new SubqueryContext(subqueryExprs); - return new LogicalProject(project.getProjects().stream() - .map(p -> p.withChildren( - new ReplaceSubquery(ctx.statementContext, true) - .replace(p, context))) - .collect(ImmutableList.toImmutableList()), - subqueryToApply( - subqueryExprs.stream().collect(ImmutableList.toImmutableList()), - (LogicalPlan) project.child(), - context.getSubqueryToMarkJoinSlot(), - ctx.cascadesContext, - Optional.empty(), true - )); - }) - ) + RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> { + LogicalProject project = ctx.root; + ImmutableList subqueryExprsList = project.getProjects().stream() + .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) + .collect(ImmutableList.toImmutableList()); + if (subqueryExprsList.stream().flatMap(Collection::stream) + .noneMatch(SubqueryExpr.class::isInstance)) { + return project; + } + List oldProjects = ImmutableList.copyOf(project.getProjects()); + List newProjects = Lists.newArrayList(); + LogicalPlan childPlan = (LogicalPlan) project.child(); + LogicalPlan applyPlan; + for (int i = 0; i < subqueryExprsList.size(); ++i) { + Set subqueryExprs = subqueryExprsList.get(i); + if (subqueryExprs.isEmpty()) { + newProjects.add(oldProjects.get(i)); + continue; + } + + // first step: Replace the subquery in logcialProject's project list + // second step: Replace subquery with LogicalApply + ReplaceSubquery replaceSubquery = + new ReplaceSubquery(ctx.statementContext, true); + SubqueryContext context = new SubqueryContext(subqueryExprs); + Expression newProject = + replaceSubquery.replace(oldProjects.get(i), context); + + applyPlan = subqueryToApply( + subqueryExprs.stream().collect(ImmutableList.toImmutableList()), + childPlan, context.getSubqueryToMarkJoinSlot(), + ctx.cascadesContext, + Optional.of(newProject), true); + childPlan = applyPlan; + newProjects.add((NamedExpression) newProject); + } + + return project.withProjectsAndChild(newProjects, childPlan); + })) ); } @@ -249,28 +257,30 @@ public Expression visitExistsSubquery(Exists exists, SubqueryContext context) { // The result set when NULL is specified in the subquery and still evaluates to TRUE by using EXISTS // When the number of rows returned is empty, agg will return null, so if there is more agg, // it will always consider the returned result to be true + boolean needCreateMarkJoinSlot = isMarkJoin || isProject; MarkJoinSlotReference markJoinSlotReference = null; - if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && isMarkJoin) { + if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && needCreateMarkJoinSlot) { markJoinSlotReference = new MarkJoinSlotReference(statementContext.generateColumnName(), true); - } else if (isMarkJoin) { + } else if (needCreateMarkJoinSlot) { markJoinSlotReference = new MarkJoinSlotReference(statementContext.generateColumnName()); } - if (isMarkJoin) { + if (needCreateMarkJoinSlot) { context.setSubqueryToMarkJoinSlot(exists, Optional.of(markJoinSlotReference)); } - return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE; + return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE; } @Override public Expression visitInSubquery(InSubquery in, SubqueryContext context) { MarkJoinSlotReference markJoinSlotReference = new MarkJoinSlotReference(statementContext.generateColumnName()); - if (isMarkJoin) { + boolean needCreateMarkJoinSlot = isMarkJoin || isProject; + if (needCreateMarkJoinSlot) { context.setSubqueryToMarkJoinSlot(in, Optional.of(markJoinSlotReference)); } - return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE; + return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 03962ff75205f7..f0971c94ba6e84 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -29,8 +29,8 @@ import org.apache.doris.nereids.properties.RequireProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE; -import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java index 42649cbac1ead4..7839fcfe95d031 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; @@ -45,7 +46,9 @@ public class PushdownAliasThroughJoin extends OneRewriteRuleFactory { public Rule build() { return logicalProject(logicalJoin()) .when(project -> project.getProjects().stream().allMatch(expr -> - (expr instanceof Slot) || (expr instanceof Alias && ((Alias) expr).child() instanceof Slot))) + (expr instanceof Slot && !(expr instanceof MarkJoinSlotReference)) + || (expr instanceof Alias && ((Alias) expr).child() instanceof Slot + && !(((Alias) expr).child() instanceof MarkJoinSlotReference)))) .when(project -> project.getProjects().stream().anyMatch(expr -> expr instanceof Alias)) .then(project -> { LogicalJoin join = project.child(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java index c9233d5c146b70..11456e8f94c3b3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java @@ -94,7 +94,7 @@ public String toString() { StringBuilder output = new StringBuilder("CASE"); for (Expression child : children()) { if (child instanceof WhenClause) { - output.append(child); + output.append(child.toString()); } else { output.append(" ELSE ").append(child.toString()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java index bdd26460c06b96..b155fe307563c8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java @@ -58,4 +58,9 @@ public String toString() { nf.setGroupingUsed(false); return nf.format(value); } + + @Override + public String getStringValue() { + return toString(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java index 4fff7445efae4d..95549901dda2c2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java @@ -22,6 +22,8 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.FloatType; +import java.text.NumberFormat; + /** * float type literal */ @@ -48,4 +50,11 @@ public R accept(ExpressionVisitor visitor, C context) { public LiteralExpr toLegacyLiteral() { return new org.apache.doris.analysis.FloatLiteral((double) value, Type.FLOAT); } + + @Override + public String getStringValue() { + NumberFormat nf = NumberFormat.getInstance(); + nf.setGroupingUsed(false); + return nf.format(value); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java index bf060d7e5dd8a4..73422bee702400 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java @@ -156,18 +156,20 @@ public void testWhereSql2AfterAnalyzed() { .matchesNotCheck( logicalApply( any(), - logicalAggregate( - logicalFilter() - ).when(FieldChecker.check("outputExpressions", ImmutableList.of( - new Alias(new ExprId(7), - (new Sum( - new SlotReference(new ExprId(4), "k3", - BigIntType.INSTANCE, true, - ImmutableList.of( - "default_cluster:test", - "t7")))).withAlwaysNullable( - true), - "sum(k3)")))) + logicalProject( + logicalAggregate( + logicalProject() + ).when(FieldChecker.check("outputExpressions", ImmutableList.of( + new Alias(new ExprId(7), + (new Sum( + new SlotReference(new ExprId(4), "k3", + BigIntType.INSTANCE, true, + ImmutableList.of( + "default_cluster:test", + "t7")))).withAlwaysNullable( + true), + "sum(k3)")))) + ) ).when(FieldChecker.check("correlationSlot", ImmutableList.of( new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t6")) @@ -383,28 +385,32 @@ public void testSql10AfterAnalyze() { logicalProject( logicalApply( any(), - logicalAggregate( - logicalSubQueryAlias( + logicalProject( + logicalAggregate( logicalProject( - logicalFilter() - ).when(p -> p.getProjects().equals(ImmutableList.of( - new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE, - true, - ImmutableList.of("default_cluster:test", "t7")), "aa") - ))) - ) - .when(a -> a.getAlias().equals("t2")) - .when(a -> a.getOutput().equals(ImmutableList.of( - new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, - true, ImmutableList.of("t2")) + logicalSubQueryAlias( + logicalProject( + logicalFilter() + ).when(p -> p.getProjects().equals(ImmutableList.of( + new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE, + true, + ImmutableList.of("default_cluster:test", "t7")), "aa") + ))) + ) + .when(a -> a.getAlias().equals("t2")) + .when(a -> a.getOutput().equals(ImmutableList.of( + new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, + true, ImmutableList.of("t2")) + ))) + ) + ).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of( + new Alias(new ExprId(8), + (new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, + true, + ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)") ))) - ).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of( - new Alias(new ExprId(8), - (new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, - true, - ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)") - ))) - .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of())) + .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of())) + ) ) .when(apply -> apply.getCorrelationSlot().equals(ImmutableList.of( new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java index 6f3bfaa7e53314..34c16309181466 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.implementation.AggregateStrategies; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.AggregateExpression; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java index e676caa37a8c98..476131e6b068b1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java index 32f7b324f9af47..29280e29c7c732 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; diff --git a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out new file mode 100644 index 00000000000000..5b97935639059d --- /dev/null +++ b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out @@ -0,0 +1,50 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql1 -- +3 + +-- !sql2 -- +3 + +-- !sql3 -- +3 + +-- !sql4 -- +false + +-- !sql5 -- +false + +-- !sql6 -- +true + +-- !sql7 -- +2 + +-- !sql8 -- +4 +4 + +-- !sql9 -- +4 +4 + +-- !sql10 -- +false +true + +-- !sql11 -- +false +true + +-- !sql12 -- +true +true + +-- !sql13 -- +2 +2 + +-- !sql14 -- +\N 2.0 +2020-09-09 2.0 + diff --git a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy new file mode 100644 index 00000000000000..0521334d8ae881 --- /dev/null +++ b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_subquery_in_project") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql """drop table if exists test_sql;""" + sql """ + CREATE TABLE `test_sql` ( + `user_id` varchar(10) NULL, + `dt` date NULL, + `city` varchar(20) NULL, + `age` int(11) NULL + ) ENGINE=OLAP + UNIQUE KEY(`user_id`) + COMMENT 'test' + DISTRIBUTED BY HASH(`user_id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "is_being_synced" = "false", + "storage_format" = "V2", + "light_schema_change" = "true", + "disable_auto_compaction" = "false", + "enable_single_replica_compaction" = "false" + ); + """ + + sql """ insert into test_sql values (1,'2020-09-09',2,3);""" + + qt_sql1 """ + select (select age from test_sql) col from test_sql order by col; + """ + + qt_sql2 """ + select (select sum(age) from test_sql) col from test_sql order by col; + """ + + qt_sql3 """ + select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col; + """ + + qt_sql4 """ + select age in (select user_id from test_sql) col from test_sql order by col; + """ + + qt_sql5 """ + select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col; + """ + + qt_sql6 """ + select exists ( select user_id from test_sql ) col from test_sql order by col; + """ + + qt_sql7 """ + select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col; + """ + + sql """ insert into test_sql values (2,'2020-09-09',2,1);""" + + try { + sql """ + select (select age from test_sql) col from test_sql order by col; + """ + } catch (Exception ex) { + assertTrue(ex.getMessage().contains("Expected EQ 1 to be returned by expression")) + } + + qt_sql8 """ + select (select sum(age) from test_sql) col from test_sql order by col; + """ + + qt_sql9 """ + select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col; + """ + + qt_sql10 """ + select age in (select user_id from test_sql) col from test_sql order by col; + """ + + qt_sql11 """ + select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col; + """ + + qt_sql12 """ + select exists ( select user_id from test_sql ) col from test_sql order by col; + """ + + qt_sql13 """ + select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col; + """ + + qt_sql14 """ + select dt,case when 'med'='med' then ( + select sum(midean) from ( + select sum(score) / count(*) as midean + from ( + select age score,row_number() over (order by age desc) as desc_math, + row_number() over (order by age asc) as asc_math from test_sql + ) as order_table + where asc_math in (desc_math, desc_math + 1, desc_math - 1)) m + ) + end 'test' from test_sql group by cube(dt) order by dt; + """ + + sql """drop table if exists test_sql;""" +} From ae81dc3360fae10206a4885005aec425520354ac Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Thu, 24 Aug 2023 16:52:42 +0800 Subject: [PATCH 2/4] fix failed test cases --- .../doris/nereids/jobs/executor/Analyzer.java | 6 + .../doris/nereids/jobs/executor/Rewriter.java | 10 +- .../EliminateGroupByConstant.java | 3 +- .../rules/analysis/SubqueryToApply.java | 4 +- .../rules/analysis/AnalyzeCTETest.java | 2 +- .../rules/analysis/BindSlotReferenceTest.java | 7 +- .../analysis/FillUpMissingSlotsTest.java | 248 ++++++++++-------- .../rules/rewrite/ColumnPruningTest.java | 17 +- .../rewrite/EliminateGroupByConstantTest.java | 1 + ...ushdownExpressionsInHashConditionTest.java | 34 ++- .../rewrite/mv/SelectRollupIndexTest.java | 4 +- .../shape/query1.out | 15 +- .../shape/query30.out | 15 +- .../shape/query51.out | 42 +-- .../shape/query81.out | 15 +- .../shape/q20.out | 13 +- .../nereids_tpch_shape_sf500_p0/shape/q20.out | 13 +- 17 files changed, 240 insertions(+), 209 deletions(-) rename fe/fe-core/src/main/java/org/apache/doris/nereids/rules/{rewrite => analysis}/EliminateGroupByConstant.java (96%) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java index c28abb120b48f6..f2acdc0a38a6ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.rules.analysis.CheckAnalysis; import org.apache.doris.nereids.rules.analysis.CheckBound; import org.apache.doris.nereids.rules.analysis.CheckPolicy; +import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots; import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.analysis.NormalizeRepeat; @@ -112,6 +113,11 @@ private static List buildAnalyzeJobs(Optional c new NormalizeRepeat() ), bottomUp(new AdjustAggregateNullableForEmptySet()), + // run CheckAnalysis before EliminateGroupByConstant in order to report error message correctly like bellow + // select SUM(lo_tax) FROM lineorder group by 1; + // errCode = 2, detailMessage = GROUP BY expression must not contain aggregate functions: sum(lo_tax) + bottomUp(new CheckAnalysis()), + topDown(new EliminateGroupByConstant()), topDown(new NormalizeAggregate()), bottomUp(new SubqueryToApply()), bottomUp(new CheckAnalysis()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index f188a309a86146..986947c262575e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet; import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount; import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; +import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite; @@ -53,7 +54,6 @@ import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition; import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation; import org.apache.doris.nereids.rules.rewrite.EliminateFilter; -import org.apache.doris.nereids.rules.rewrite.EliminateGroupByConstant; import org.apache.doris.nereids.rules.rewrite.EliminateLimit; import org.apache.doris.nereids.rules.rewrite.EliminateNotNull; import org.apache.doris.nereids.rules.rewrite.EliminateNullAwareLeftAntiJoin; @@ -80,8 +80,8 @@ import org.apache.doris.nereids.rules.rewrite.PruneOlapScanPartition; import org.apache.doris.nereids.rules.rewrite.PruneOlapScanTablet; import org.apache.doris.nereids.rules.rewrite.PullUpCteAnchor; -import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan; import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderApply; +import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan; import org.apache.doris.nereids.rules.rewrite.PushFilterInsideJoin; import org.apache.doris.nereids.rules.rewrite.PushProjectIntoOneRowRelation; @@ -140,9 +140,9 @@ public class Rewriter extends AbstractBatchJobExecutor { ), // subquery unnesting relay on ExpressionNormalization to extract common factor expression topic("Subquery unnesting", - // after moving NormalizeAggregate into analysis job - // we need run the following 3 rules before subquery unnesting - bottomUp(new PullUpProjectUnderApply(), new MergeProjects()), + // after doing NormalizeAggregate in analysis job + // we need run the following 2 rules to make AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION work + bottomUp(new PullUpProjectUnderApply()), topDown(new PushdownFilterThroughProject()), costBased( custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java similarity index 96% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstant.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java index f5a01fe530543a..e7fa14e5cb24f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstant.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.Plan; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index 388c4eebba8db1..26a5bd10adf129 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -46,7 +46,6 @@ import java.util.Collection; import java.util.LinkedHashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -103,8 +102,7 @@ public List buildRules() { tmpPlan = applyPlan; newConjuncts.add(conjunct); } - Set conjuncts = new LinkedHashSet<>(); - conjuncts.addAll(newConjuncts.build()); + Set conjuncts = ImmutableSet.copyOf(newConjuncts.build()); Plan newFilter = new LogicalFilter<>(conjuncts, applyPlan); if (conjuncts.stream().flatMap(c -> c.children().stream()) .anyMatch(MarkJoinSlotReference.class::isInstance)) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java index 522f198e3ff774..ef5a32e2d3bcdb 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java @@ -140,7 +140,7 @@ public void testCTEInHavingAndSubquery() { logicalFilter( logicalProject( logicalJoin( - logicalAggregate(), + logicalProject(logicalAggregate()), logicalProject() ) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java index 0a3334b4cf0c2b..dc05ec062637ea 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java @@ -90,10 +90,11 @@ public void testGroupByOnJoin() { join ); PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate); - LogicalAggregate plan = (LogicalAggregate) checker.getCascadesContext().getMemo().copyOut(); + LogicalAggregate plan = (LogicalAggregate) ((LogicalProject) checker.getCascadesContext() + .getMemo().copyOut()).child(); SlotReference groupByKey = (SlotReference) plan.getGroupByExpressions().get(0); - SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child()).left().getOutput().get(0); - SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child()).right().getOutput().get(0); + SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child().child(0)).left().getOutput().get(0); + SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child().child(0)).right().getOutput().get(0); Assertions.assertEquals(groupByKey.getExprId(), t1id.getExprId()); Assertions.assertNotEquals(t1id.getExprId(), t2id.getExprId()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index 03cc549bc2c1fa..64405e3bc1c22d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -35,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.FieldChecker; import org.apache.doris.nereids.util.MemoPatternMatchSupported; @@ -86,35 +87,34 @@ public void testHavingGroupBySlot() { ); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))); + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))))); sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0"; - a1 = new SlotReference( - new ExprId(1), "a1", TinyIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") - ); - Alias value = new Alias(new ExprId(3), a1, "value"); + SlotReference value = new SlotReference(new ExprId(3), "value", TinyIntType.INSTANCE, true, ImmutableList.of()); PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))))); sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING value > 0"; PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))) ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))); sql = "SELECT SUM(a2) FROM t1 GROUP BY a1 HAVING a1 > 0"; @@ -130,13 +130,14 @@ public void testHavingGroupBySlot() { PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(sumA2, a1))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot())))); } @Test @@ -153,13 +154,14 @@ public void testHavingAggregateFunction() { Alias sumA2 = new Alias(new ExprId(3), new Sum(a2), "sum(a2)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + )).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING SUM(a2) > 0"; sumA2 = new Alias(new ExprId(3), new Sum(a2), "SUM(a2)"); @@ -281,19 +283,21 @@ void testJoinWithHaving() { Alias sumB1 = new Alias(new ExprId(7), new Sum(b1), "sum(b1)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( + logicalProject( logicalFilter( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) - ) - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE), - sumB1.toSlot())))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); + logicalProject( + logicalAggregate( + logicalProject( + logicalFilter( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + )) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1))) + )).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE), + sumB1.toSlot())))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); } @Test @@ -348,31 +352,33 @@ void testComplexQueryWithHaving() { Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( + logicalProject( logicalFilter( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) + logicalProject( + logicalAggregate( + logicalProject( + logicalFilter( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + )) + )).when(FieldChecker.check("outputExpressions", + Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk))) + ).when(FieldChecker.check("conjuncts", + ImmutableSet.of( + new GreaterThan(pk.toSlot(), Literal.of((byte) 0)), + new GreaterThan(countA11.toSlot(), Literal.of(0L)), + new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), + new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), + new GreaterThan(v1.toSlot(), Literal.of(0L)) + )) ) - ).when(FieldChecker.check("outputExpressions", - Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk))) - ).when(FieldChecker.check("conjuncts", - ImmutableSet.of( - new GreaterThan(pk.toSlot(), Literal.of((byte) 0)), - new GreaterThan(countA11.toSlot(), Literal.of(0L)), - new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), - new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), - new GreaterThan(v1.toSlot(), Literal.of(0L)) - )) - ) - ).when(FieldChecker.check( - "projects", Lists.newArrayList( - pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream() - .map(Alias::toSlot).collect(Collectors.toList())) - )); + ).when(FieldChecker.check( + "projects", Lists.newArrayList( + pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream() + .map(Alias::toSlot).collect(Collectors.toList())) + )); } @Test @@ -391,9 +397,10 @@ public void testSortAggregateFunction() { .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); @@ -402,9 +409,10 @@ public void testSortAggregateFunction() { PlanChecker.from(connectContext).analyze(sql) .matches( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 ORDER BY SUM(a2)"; @@ -420,9 +428,10 @@ public void testSortAggregateFunction() { PlanChecker.from(connectContext).analyze(sql) .matches( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 ORDER BY MIN(pk)"; @@ -444,9 +453,10 @@ public void testSortAggregateFunction() { .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(minPK.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); @@ -455,9 +465,10 @@ public void testSortAggregateFunction() { PlanChecker.from(connectContext).analyze(sql) .matches( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 ORDER BY SUM(a1 + a2 + 3)"; @@ -467,9 +478,10 @@ public void testSortAggregateFunction() { .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A23.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); @@ -479,9 +491,10 @@ public void testSortAggregateFunction() { .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(countStar.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); } @@ -495,6 +508,10 @@ void testComplexQueryWithOrderBy() { new ExprId(0), "pk", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); + SlotReference pk1 = new SlotReference( + new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true, + ImmutableList.of() + ); SlotReference a1 = new SlotReference( new ExprId(1), "a1", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") @@ -503,40 +520,41 @@ void testComplexQueryWithOrderBy() { new ExprId(2), "a2", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); - Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)"); Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)"); + Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)"); Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1"); PlanChecker.from(connectContext).analyze(sql) - .matches( - logicalProject( - logicalSort( - logicalAggregate( - logicalFilter( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) - ) - ).when(FieldChecker.check("outputExpressions", - Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk))) - ).when(FieldChecker.check("orderKeys", - ImmutableList.of( - new OrderKey(pk, true, true), - new OrderKey(countA11.toSlot(), true, true), - new OrderKey(new Add(sumA1A2.toSlot(), new TinyIntLiteral((byte) 1)), true, true), - new OrderKey(new Add(v1.toSlot(), new TinyIntLiteral((byte) 1)), true, true), - new OrderKey(v1.toSlot(), true, true) - ) - )) - ).when(FieldChecker.check( - "projects", Lists.newArrayList( - pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream() - .map(Alias::toSlot).collect(Collectors.toList())) - )); + .matches(logicalProject(logicalSort(logicalProject(logicalAggregate(logicalProject( + logicalFilter(logicalJoin(logicalOlapScan(), logicalOlapScan())))).when( + FieldChecker.check("outputExpressions", Lists.newArrayList(pk, pk1, + sumA1, countA1, sumA1A2, v1))))).when(FieldChecker.check( + "orderKeys", + ImmutableList.of(new OrderKey(pk, true, true), + new OrderKey( + countA11.toSlot(), true, true), + new OrderKey( + new Add(sumA1A2.toSlot(), + new TinyIntLiteral( + (byte) 1)), + true, true), + new OrderKey( + new Add(v1.toSlot(), + new TinyIntLiteral( + (byte) 1)), + true, true), + new OrderKey(v1.toSlot(), true, true))))) + .when(FieldChecker.check("projects", + Lists.newArrayList(pk1, + pk11.toSlot(), + pk2.toSlot(), + sumA1.toSlot(), + countA11.toSlot(), + sumA1A2.toSlot(), + v1.toSlot())))); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java index 04e84ab8e89d86..5c43d7274d3fcc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java @@ -299,15 +299,16 @@ public void pruneAggregateOutput() { .matches( logicalProject( logicalSubQueryAlias( - logicalAggregate( logicalProject( - logicalOlapScan() - ).when(p -> getOutputQualifiedNames(p).equals( - ImmutableList.of("default_cluster:test.student.id") - )) - ).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals( - ImmutableList.of("default_cluster:test.student.id") - )) + logicalAggregate( + logicalProject( + logicalOlapScan() + ).when(p -> getOutputQualifiedNames(p).equals( + ImmutableList.of("default_cluster:test.student.id") + )) + ).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals( + ImmutableList.of("default_cluster:test.student.id") + ))) ) ) ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java index 3fca54eed96a29..9bc0b00f50f532 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java @@ -24,6 +24,7 @@ import org.apache.doris.catalog.PartitionInfo; import org.apache.doris.catalog.Type; import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; +import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Slot; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java index 29cc509d954988..dfad75d5d8042a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java @@ -135,20 +135,22 @@ public void testAggNodeCase() { .applyTopDown(new FindHashConditionForJoin()) .applyTopDown(new PushdownExpressionsInHashCondition()) .matches( - logicalProject( - logicalJoin( - logicalProject( - logicalOlapScan() - ), - logicalProject( - logicalSubQueryAlias( - logicalAggregate( - logicalOlapScan() - ) + logicalProject( + logicalJoin( + logicalProject( + logicalOlapScan() + ), + logicalProject( + logicalSubQueryAlias( + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan() + ))) + ) + ) ) - ) ) - ) ); } @@ -168,8 +170,12 @@ public void testSortNodeCase() { logicalProject( logicalSubQueryAlias( logicalSort( - logicalAggregate( - logicalOlapScan() + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan() + ) + ) ) ) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java index a3bd46eb4f2f4e..ed5a96933db105 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java @@ -20,6 +20,7 @@ import org.apache.doris.common.FeConstants; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; import org.apache.doris.nereids.rules.rewrite.MergeProjects; +import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject; import org.apache.doris.nereids.trees.plans.PreAggStatus; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; @@ -188,7 +189,8 @@ public void testWithFilterAndProject() { PlanChecker.from(connectContext) .analyze(sql) .applyBottomUp(new LogicalSubQueryAliasToLogicalProject()) - .applyTopDown(new MergeProjects()) + .applyTopDown(new PushdownFilterThroughProject()) + .applyBottomUp(new MergeProjects()) .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out index 8c934fb1876e22..0aa36ae310959b 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out @@ -23,7 +23,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------PhysicalOlapScan[customer] --------------PhysicalDistribute -----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE)) ------------------PhysicalProject --------------------hashJoin[INNER_JOIN](store.s_store_sk = ctr1.ctr_store_sk) ----------------------PhysicalDistribute @@ -32,11 +32,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------------PhysicalProject --------------------------filter((cast(s_state as VARCHAR(*)) = 'SD')) ----------------------------PhysicalOlapScan[store] -------------------PhysicalProject ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalDistribute -----------------------------PhysicalProject -------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalDistribute +--------------------------PhysicalProject +----------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out index df28c5bee47b66..83982f37827139 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out @@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------PhysicalDistribute --------PhysicalTopN ----------PhysicalProject -------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE)) --------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk) ----------------PhysicalDistribute ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) @@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------------------filter((cast(ca_state as VARCHAR(*)) = 'IN')) ----------------------------PhysicalOlapScan[customer_address] --------------PhysicalDistribute -----------------PhysicalProject -------------------hashAgg[GLOBAL] ---------------------PhysicalDistribute -----------------------hashAgg[LOCAL] -------------------------PhysicalDistribute ---------------------------PhysicalProject -----------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +----------------hashAgg[GLOBAL] +------------------PhysicalDistribute +--------------------hashAgg[LOCAL] +----------------------PhysicalDistribute +------------------------PhysicalProject +--------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out index 8ba49dc8d60bdd..b8d6435601a5d9 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out @@ -14,30 +14,32 @@ PhysicalResultSink ----------------------PhysicalWindow ------------------------PhysicalQuickSort --------------------------PhysicalDistribute -----------------------------hashAgg[GLOBAL] -------------------------------PhysicalDistribute ---------------------------------hashAgg[LOCAL] -----------------------------------PhysicalProject -------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk) ---------------------------------------PhysicalProject -----------------------------------------PhysicalOlapScan[store_sales] ---------------------------------------PhysicalDistribute +----------------------------PhysicalProject +------------------------------hashAgg[GLOBAL] +--------------------------------PhysicalDistribute +----------------------------------hashAgg[LOCAL] +------------------------------------PhysicalProject +--------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk) ----------------------------------------PhysicalProject -------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216)) ---------------------------------------------PhysicalOlapScan[date_dim] +------------------------------------------PhysicalOlapScan[store_sales] +----------------------------------------PhysicalDistribute +------------------------------------------PhysicalProject +--------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216)) +----------------------------------------------PhysicalOlapScan[date_dim] --------------------PhysicalProject ----------------------PhysicalWindow ------------------------PhysicalQuickSort --------------------------PhysicalDistribute -----------------------------hashAgg[GLOBAL] -------------------------------PhysicalDistribute ---------------------------------hashAgg[LOCAL] -----------------------------------PhysicalProject -------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk) ---------------------------------------PhysicalProject -----------------------------------------PhysicalOlapScan[web_sales] ---------------------------------------PhysicalDistribute +----------------------------PhysicalProject +------------------------------hashAgg[GLOBAL] +--------------------------------PhysicalDistribute +----------------------------------hashAgg[LOCAL] +------------------------------------PhysicalProject +--------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk) ----------------------------------------PhysicalProject -------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227)) ---------------------------------------------PhysicalOlapScan[date_dim] +------------------------------------------PhysicalOlapScan[web_sales] +----------------------------------------PhysicalDistribute +------------------------------------------PhysicalProject +--------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227)) +----------------------------------------------PhysicalOlapScan[date_dim] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out index 77c7b273ba3d0f..bfcec6ce4127f4 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out @@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------PhysicalDistribute --------PhysicalTopN ----------PhysicalProject -------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE)) --------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk) ----------------PhysicalDistribute ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) @@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------------------filter((cast(ca_state as VARCHAR(*)) = 'CA')) ----------------------------PhysicalOlapScan[customer_address] --------------PhysicalDistribute -----------------PhysicalProject -------------------hashAgg[GLOBAL] ---------------------PhysicalDistribute -----------------------hashAgg[LOCAL] -------------------------PhysicalDistribute ---------------------------PhysicalProject -----------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +----------------hashAgg[GLOBAL] +------------------PhysicalDistribute +--------------------hashAgg[LOCAL] +----------------------PhysicalDistribute +------------------------PhysicalProject +--------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out index 6114877bc9dd66..9913e27f5fb37b 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out @@ -9,13 +9,12 @@ PhysicalResultSink ------------PhysicalDistribute --------------PhysicalProject ----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) -------------------PhysicalProject ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalProject -----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) -------------------------------PhysicalOlapScan[lineitem] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalProject +--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) +----------------------------PhysicalOlapScan[lineitem] ------------------PhysicalDistribute --------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey) ----------------------PhysicalProject diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out index 6114877bc9dd66..9913e27f5fb37b 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out @@ -9,13 +9,12 @@ PhysicalResultSink ------------PhysicalDistribute --------------PhysicalProject ----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) -------------------PhysicalProject ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalProject -----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) -------------------------------PhysicalOlapScan[lineitem] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalProject +--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) +----------------------------PhysicalOlapScan[lineitem] ------------------PhysicalDistribute --------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey) ----------------------PhysicalProject From ee8012311042b152d52a497fe53bf7533e2cd090 Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Thu, 24 Aug 2023 21:51:27 +0800 Subject: [PATCH 3/4] fix fe ut --- .../analysis/FillUpMissingSlotsTest.java | 118 ++++++++++-------- 1 file changed, 66 insertions(+), 52 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index 64405e3bc1c22d..8cacb4609186c3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -46,8 +46,6 @@ import com.google.common.collect.Lists; import org.junit.jupiter.api.Test; -import java.util.stream.Collectors; - public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements MemoPatternMatchSupported { @Override @@ -94,7 +92,8 @@ public void testHavingGroupBySlot() { ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))))); sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0"; - SlotReference value = new SlotReference(new ExprId(3), "value", TinyIntType.INSTANCE, true, ImmutableList.of()); + SlotReference value = new SlotReference(new ExprId(3), "value", TinyIntType.INSTANCE, true, + ImmutableList.of()); PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( @@ -159,7 +158,7 @@ public void testHavingAggregateFunction() { logicalProject( logicalAggregate( logicalProject(logicalOlapScan()) - )).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); @@ -167,12 +166,15 @@ public void testHavingAggregateFunction() { sumA2 = new Alias(new ExprId(3), new Sum(a2), "SUM(a2)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan() + ) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING SUM(a2) > 0"; a1 = new SlotReference( @@ -186,20 +188,24 @@ public void testHavingAggregateFunction() { Alias value = new Alias(new ExprId(3), new Sum(a2), "value"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0"; PlanChecker.from(connectContext).analyze(sql) .matches( logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))) ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0"; @@ -219,49 +225,53 @@ public void testHavingAggregateFunction() { Alias minPK = new Alias(new ExprId(4), new Min(pk), "min(pk)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2) > 0"; Alias sumA1A2 = new Alias(new ExprId(3), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L))))))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0"; Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new TinyIntLiteral((byte) 3))), "sum(((a1 + a2) + 3))"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); sql = "SELECT a1 FROM t1 GROUP BY a1 HAVING COUNT(*) > 0"; Alias countStar = new Alias(new ExprId(3), new Count(), "count(*)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); } @Test @@ -335,6 +345,10 @@ void testComplexQueryWithHaving() { new ExprId(0), "pk", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); + SlotReference pk1 = new SlotReference( + new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true, + ImmutableList.of() + ); SlotReference a1 = new SlotReference( new ExprId(1), "a1", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") @@ -343,11 +357,11 @@ void testComplexQueryWithHaving() { new ExprId(2), "a2", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); - Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)"); Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)"); - Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); + Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)"); + Alias countA11 = new Alias(new ExprId(10), new Add(countA1.toSlot(), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1"); PlanChecker.from(connectContext).analyze(sql) @@ -363,8 +377,8 @@ void testComplexQueryWithHaving() { logicalOlapScan() ) )) - )).when(FieldChecker.check("outputExpressions", - Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk))) + ).when(FieldChecker.check("outputExpressions", + Lists.newArrayList(pk, pk1, sumA1, countA1, sumA1A2, v1)))) ).when(FieldChecker.check("conjuncts", ImmutableSet.of( new GreaterThan(pk.toSlot(), Literal.of((byte) 0)), @@ -376,8 +390,8 @@ void testComplexQueryWithHaving() { ) ).when(FieldChecker.check( "projects", Lists.newArrayList( - pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream() - .map(Alias::toSlot).collect(Collectors.toList())) + pk1, pk11.toSlot(), pk2.toSlot(), sumA1.toSlot(), countA11.toSlot(), sumA1A2.toSlot(), v1.toSlot()) + ) )); } From cd18d91d9ba2ed286322f053103c54b4abda374b Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Fri, 25 Aug 2023 17:10:54 +0800 Subject: [PATCH 4/4] modified based on comments --- .../doris/nereids/jobs/executor/Analyzer.java | 3 +-- .../rules/analysis/SubqueryToApply.java | 16 ++++++------- .../EliminateGroupByConstantTest.java | 4 +--- .../NormalizeAggregateTest.java | 3 +-- .../rewrite/PushdownAliasThroughJoinTest.java | 23 +++++++++++++++++++ 5 files changed, 33 insertions(+), 16 deletions(-) rename fe/fe-core/src/test/java/org/apache/doris/nereids/rules/{rewrite => analysis}/EliminateGroupByConstantTest.java (97%) rename fe/fe-core/src/test/java/org/apache/doris/nereids/rules/{rewrite => analysis}/NormalizeAggregateTest.java (98%) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java index f2acdc0a38a6ef..ed67b44f1fd619 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java @@ -119,8 +119,7 @@ private static List buildAnalyzeJobs(Optional c bottomUp(new CheckAnalysis()), topDown(new EliminateGroupByConstant()), topDown(new NormalizeAggregate()), - bottomUp(new SubqueryToApply()), - bottomUp(new CheckAnalysis()) + bottomUp(new SubqueryToApply()) ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index 26a5bd10adf129..c28e82f680919c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -42,7 +42,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import java.util.Collection; import java.util.LinkedHashMap; @@ -66,8 +65,8 @@ public List buildRules() { logicalFilter().thenApply(ctx -> { LogicalFilter filter = ctx.root; - ImmutableList subqueryExprsList = filter.getConjuncts().stream() - .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) + ImmutableList> subqueryExprsList = filter.getConjuncts().stream() + .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) .collect(ImmutableList.toImmutableList()); if (subqueryExprsList.stream() .flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) { @@ -115,15 +114,14 @@ public List buildRules() { ), RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> { LogicalProject project = ctx.root; - ImmutableList subqueryExprsList = project.getProjects().stream() - .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) + ImmutableList> subqueryExprsList = project.getProjects().stream() + .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) .collect(ImmutableList.toImmutableList()); - if (subqueryExprsList.stream().flatMap(Collection::stream) - .noneMatch(SubqueryExpr.class::isInstance)) { + if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) { return project; } List oldProjects = ImmutableList.copyOf(project.getProjects()); - List newProjects = Lists.newArrayList(); + ImmutableList.Builder newProjects = new ImmutableList.Builder<>(); LogicalPlan childPlan = (LogicalPlan) project.child(); LogicalPlan applyPlan; for (int i = 0; i < subqueryExprsList.size(); ++i) { @@ -150,7 +148,7 @@ public List buildRules() { newProjects.add((NamedExpression) newProject); } - return project.withProjectsAndChild(newProjects, childPlan); + return project.withProjectsAndChild(newProjects.build(), childPlan); })) ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java similarity index 97% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java index 9bc0b00f50f532..c35b983911c859 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.catalog.AggregateType; import org.apache.doris.catalog.Column; @@ -23,8 +23,6 @@ import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.PartitionInfo; import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; -import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Slot; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java similarity index 98% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java index 29280e29c7c732..3808fd1842810f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; -import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java index 5667f3f2c5ad3d..5a98b07bcf041c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java @@ -18,6 +18,10 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -30,6 +34,8 @@ import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; +import java.util.List; + class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @@ -99,4 +105,21 @@ void testJustRightSide() { && project.getProjects().get(1).toSql().equals("2name")) ); } + + @Test + void testNoPushdownMarkJoin() { + List projects = + ImmutableList.of(new MarkJoinSlotReference(new ExprId(101), "markSlot1", false), + new Alias(new MarkJoinSlotReference(new ExprId(102), "markSlot2", false), + "markSlot2")); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)).projectExprs(projects).build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownAliasThroughJoin()) + .matches(logicalProject(logicalJoin(logicalOlapScan(), logicalOlapScan())) + .when(project -> project.getProjects().get(0).toSql().equals("markSlot1") + && project.getProjects().get(1).toSql() + .equals("markSlot2 AS `markSlot2`"))); + } }