diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java index 1fb1d7eecdf080..ed67b44f1fd619 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java @@ -30,7 +30,9 @@ import org.apache.doris.nereids.rules.analysis.CheckAnalysis; import org.apache.doris.nereids.rules.analysis.CheckBound; import org.apache.doris.nereids.rules.analysis.CheckPolicy; +import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.analysis.NormalizeRepeat; import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate; import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate; @@ -110,9 +112,14 @@ private static List buildAnalyzeJobs(Optional c // LogicalProject for normalize. This rule depends on FillUpMissingSlots to fill up slots. 
new NormalizeRepeat() ), - bottomUp(new SubqueryToApply()), bottomUp(new AdjustAggregateNullableForEmptySet()), - bottomUp(new CheckAnalysis()) + // run CheckAnalysis before EliminateGroupByConstant in order to report error message correctly like below + // select SUM(lo_tax) FROM lineorder group by 1; + // errCode = 2, detailMessage = GROUP BY expression must not contain aggregate functions: sum(lo_tax) + bottomUp(new CheckAnalysis()), + topDown(new EliminateGroupByConstant()), + topDown(new NormalizeAggregate()), + bottomUp(new SubqueryToApply()) ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index ec2ea06eace9f0..986947c262575e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -25,7 +25,9 @@ import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet; import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount; import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; +import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite; import org.apache.doris.nereids.rules.expression.ExpressionNormalization; import org.apache.doris.nereids.rules.expression.ExpressionOptimization; @@ -52,7 +54,6 @@ import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition; import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation; import org.apache.doris.nereids.rules.rewrite.EliminateFilter; -import org.apache.doris.nereids.rules.rewrite.EliminateGroupByConstant; import org.apache.doris.nereids.rules.rewrite.EliminateLimit; 
import org.apache.doris.nereids.rules.rewrite.EliminateNotNull; import org.apache.doris.nereids.rules.rewrite.EliminateNullAwareLeftAntiJoin; @@ -74,12 +75,12 @@ import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion; import org.apache.doris.nereids.rules.rewrite.MergeProjects; import org.apache.doris.nereids.rules.rewrite.MergeSetOperations; -import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate; import org.apache.doris.nereids.rules.rewrite.NormalizeSort; import org.apache.doris.nereids.rules.rewrite.PruneFileScanPartition; import org.apache.doris.nereids.rules.rewrite.PruneOlapScanPartition; import org.apache.doris.nereids.rules.rewrite.PruneOlapScanTablet; import org.apache.doris.nereids.rules.rewrite.PullUpCteAnchor; +import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderApply; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan; import org.apache.doris.nereids.rules.rewrite.PushFilterInsideJoin; @@ -139,6 +140,10 @@ public class Rewriter extends AbstractBatchJobExecutor { ), // subquery unnesting relay on ExpressionNormalization to extract common factor expression topic("Subquery unnesting", + // after doing NormalizeAggregate in analysis job + // we need run the following 2 rules to make AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION work + bottomUp(new PullUpProjectUnderApply()), + topDown(new PushdownFilterThroughProject()), costBased( custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION, AggScalarSubQueryToWindowFunction::new) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java similarity index 96% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstant.java rename to 
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java index f5a01fe530543a..e7fa14e5cb24f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstant.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.Plan; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java similarity index 98% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index eb683e8b5835db..6a141dce7a795c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. 
-package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index 6dfe95c1167b2b..c28e82f680919c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -21,9 +21,7 @@ import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.BinaryOperator; -import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Exists; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InSubquery; @@ -47,7 +45,6 @@ import java.util.Collection; import java.util.LinkedHashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -68,8 +65,8 @@ public List buildRules() { logicalFilter().thenApply(ctx -> { LogicalFilter filter = ctx.root; - ImmutableList subqueryExprsList = filter.getConjuncts().stream() - .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) + ImmutableList> subqueryExprsList = filter.getConjuncts().stream() + .map(e -> (Set) 
e.collect(SubqueryExpr.class::isInstance)) .collect(ImmutableList.toImmutableList()); if (subqueryExprsList.stream() .flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) { @@ -104,8 +101,7 @@ public List buildRules() { tmpPlan = applyPlan; newConjuncts.add(conjunct); } - Set conjuncts = new LinkedHashSet<>(); - conjuncts.addAll(newConjuncts.build()); + Set conjuncts = ImmutableSet.copyOf(newConjuncts.build()); Plan newFilter = new LogicalFilter<>(conjuncts, applyPlan); if (conjuncts.stream().flatMap(c -> c.children().stream()) .anyMatch(MarkJoinSlotReference.class::isInstance)) { @@ -116,36 +112,44 @@ public List buildRules() { return new LogicalFilter<>(conjuncts, applyPlan); }) ), - RuleType.PROJECT_SUBQUERY_TO_APPLY.build( - logicalProject().thenApply(ctx -> { - LogicalProject project = ctx.root; - Set subqueryExprs = new LinkedHashSet<>(); - project.getProjects().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .filter(alias -> alias.child() instanceof CaseWhen) - .forEach(alias -> alias.child().children().stream() - .forEach(e -> - subqueryExprs.addAll(e.collect(SubqueryExpr.class::isInstance)))); - if (subqueryExprs.isEmpty()) { - return project; - } - - SubqueryContext context = new SubqueryContext(subqueryExprs); - return new LogicalProject(project.getProjects().stream() - .map(p -> p.withChildren( - new ReplaceSubquery(ctx.statementContext, true) - .replace(p, context))) - .collect(ImmutableList.toImmutableList()), - subqueryToApply( - subqueryExprs.stream().collect(ImmutableList.toImmutableList()), - (LogicalPlan) project.child(), - context.getSubqueryToMarkJoinSlot(), - ctx.cascadesContext, - Optional.empty(), true - )); - }) - ) + RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> { + LogicalProject project = ctx.root; + ImmutableList> subqueryExprsList = project.getProjects().stream() + .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) + 
.collect(ImmutableList.toImmutableList()); + if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) { + return project; + } + List oldProjects = ImmutableList.copyOf(project.getProjects()); + ImmutableList.Builder newProjects = new ImmutableList.Builder<>(); + LogicalPlan childPlan = (LogicalPlan) project.child(); + LogicalPlan applyPlan; + for (int i = 0; i < subqueryExprsList.size(); ++i) { + Set subqueryExprs = subqueryExprsList.get(i); + if (subqueryExprs.isEmpty()) { + newProjects.add(oldProjects.get(i)); + continue; + } + + // first step: Replace the subquery in logicalProject's project list + // second step: Replace subquery with LogicalApply + ReplaceSubquery replaceSubquery = + new ReplaceSubquery(ctx.statementContext, true); + SubqueryContext context = new SubqueryContext(subqueryExprs); + Expression newProject = + replaceSubquery.replace(oldProjects.get(i), context); + + applyPlan = subqueryToApply( + subqueryExprs.stream().collect(ImmutableList.toImmutableList()), + childPlan, context.getSubqueryToMarkJoinSlot(), + ctx.cascadesContext, + Optional.of(newProject), true); + childPlan = applyPlan; + newProjects.add((NamedExpression) newProject); + } + + return project.withProjectsAndChild(newProjects.build(), childPlan); + })) ); } @@ -249,28 +253,30 @@ public Expression visitExistsSubquery(Exists exists, SubqueryContext context) { // The result set when NULL is specified in the subquery and still evaluates to TRUE by using EXISTS // When the number of rows returned is empty, agg will return null, so if there is more agg, // it will always consider the returned result to be true + boolean needCreateMarkJoinSlot = isMarkJoin || isProject; MarkJoinSlotReference markJoinSlotReference = null; - if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && isMarkJoin) { + if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && needCreateMarkJoinSlot) { markJoinSlotReference = new 
MarkJoinSlotReference(statementContext.generateColumnName(), true); - } else if (isMarkJoin) { + } else if (needCreateMarkJoinSlot) { markJoinSlotReference = new MarkJoinSlotReference(statementContext.generateColumnName()); } - if (isMarkJoin) { + if (needCreateMarkJoinSlot) { context.setSubqueryToMarkJoinSlot(exists, Optional.of(markJoinSlotReference)); } - return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE; + return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE; } @Override public Expression visitInSubquery(InSubquery in, SubqueryContext context) { MarkJoinSlotReference markJoinSlotReference = new MarkJoinSlotReference(statementContext.generateColumnName()); - if (isMarkJoin) { + boolean needCreateMarkJoinSlot = isMarkJoin || isProject; + if (needCreateMarkJoinSlot) { context.setSubqueryToMarkJoinSlot(in, Optional.of(markJoinSlotReference)); } - return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE; + return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 03962ff75205f7..f0971c94ba6e84 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -29,8 +29,8 @@ import org.apache.doris.nereids.properties.RequireProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE; -import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.AggregateExpression; import 
org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java index 42649cbac1ead4..7839fcfe95d031 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; @@ -45,7 +46,9 @@ public class PushdownAliasThroughJoin extends OneRewriteRuleFactory { public Rule build() { return logicalProject(logicalJoin()) .when(project -> project.getProjects().stream().allMatch(expr -> - (expr instanceof Slot) || (expr instanceof Alias && ((Alias) expr).child() instanceof Slot))) + (expr instanceof Slot && !(expr instanceof MarkJoinSlotReference)) + || (expr instanceof Alias && ((Alias) expr).child() instanceof Slot + && !(((Alias) expr).child() instanceof MarkJoinSlotReference)))) .when(project -> project.getProjects().stream().anyMatch(expr -> expr instanceof Alias)) .then(project -> { LogicalJoin join = project.child(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java index c9233d5c146b70..11456e8f94c3b3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java @@ -94,7 +94,7 @@ public String toString() { StringBuilder output = new StringBuilder("CASE"); for (Expression child : children()) { if (child instanceof WhenClause) { - output.append(child); + output.append(child.toString()); } else { output.append(" ELSE ").append(child.toString()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java index bdd26460c06b96..b155fe307563c8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java @@ -58,4 +58,9 @@ public String toString() { nf.setGroupingUsed(false); return nf.format(value); } + + @Override + public String getStringValue() { + return toString(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java index 4fff7445efae4d..95549901dda2c2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java @@ -22,6 +22,8 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.FloatType; +import java.text.NumberFormat; + /** * float type literal */ @@ -48,4 +50,11 @@ public R accept(ExpressionVisitor visitor, C context) { public LiteralExpr toLegacyLiteral() { return new org.apache.doris.analysis.FloatLiteral((double) value, Type.FLOAT); } + + @Override + public String getStringValue() { + NumberFormat nf = NumberFormat.getInstance(); + nf.setGroupingUsed(false); + return nf.format(value); 
+ } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java index 522f198e3ff774..ef5a32e2d3bcdb 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java @@ -140,7 +140,7 @@ public void testCTEInHavingAndSubquery() { logicalFilter( logicalProject( logicalJoin( - logicalAggregate(), + logicalProject(logicalAggregate()), logicalProject() ) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java index bf060d7e5dd8a4..73422bee702400 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java @@ -156,18 +156,20 @@ public void testWhereSql2AfterAnalyzed() { .matchesNotCheck( logicalApply( any(), - logicalAggregate( - logicalFilter() - ).when(FieldChecker.check("outputExpressions", ImmutableList.of( - new Alias(new ExprId(7), - (new Sum( - new SlotReference(new ExprId(4), "k3", - BigIntType.INSTANCE, true, - ImmutableList.of( - "default_cluster:test", - "t7")))).withAlwaysNullable( - true), - "sum(k3)")))) + logicalProject( + logicalAggregate( + logicalProject() + ).when(FieldChecker.check("outputExpressions", ImmutableList.of( + new Alias(new ExprId(7), + (new Sum( + new SlotReference(new ExprId(4), "k3", + BigIntType.INSTANCE, true, + ImmutableList.of( + "default_cluster:test", + "t7")))).withAlwaysNullable( + true), + "sum(k3)")))) + ) ).when(FieldChecker.check("correlationSlot", ImmutableList.of( new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t6")) @@ 
-383,28 +385,32 @@ public void testSql10AfterAnalyze() { logicalProject( logicalApply( any(), - logicalAggregate( - logicalSubQueryAlias( + logicalProject( + logicalAggregate( logicalProject( - logicalFilter() - ).when(p -> p.getProjects().equals(ImmutableList.of( - new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE, - true, - ImmutableList.of("default_cluster:test", "t7")), "aa") - ))) - ) - .when(a -> a.getAlias().equals("t2")) - .when(a -> a.getOutput().equals(ImmutableList.of( - new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, - true, ImmutableList.of("t2")) + logicalSubQueryAlias( + logicalProject( + logicalFilter() + ).when(p -> p.getProjects().equals(ImmutableList.of( + new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE, + true, + ImmutableList.of("default_cluster:test", "t7")), "aa") + ))) + ) + .when(a -> a.getAlias().equals("t2")) + .when(a -> a.getOutput().equals(ImmutableList.of( + new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, + true, ImmutableList.of("t2")) + ))) + ) + ).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of( + new Alias(new ExprId(8), + (new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, + true, + ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)") ))) - ).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of( - new Alias(new ExprId(8), - (new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, - true, - ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)") - ))) - .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of())) + .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of())) + ) ) .when(apply -> apply.getCorrelationSlot().equals(ImmutableList.of( new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java index 0a3334b4cf0c2b..dc05ec062637ea 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java @@ -90,10 +90,11 @@ public void testGroupByOnJoin() { join ); PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate); - LogicalAggregate plan = (LogicalAggregate) checker.getCascadesContext().getMemo().copyOut(); + LogicalAggregate plan = (LogicalAggregate) ((LogicalProject) checker.getCascadesContext() + .getMemo().copyOut()).child(); SlotReference groupByKey = (SlotReference) plan.getGroupByExpressions().get(0); - SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child()).left().getOutput().get(0); - SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child()).right().getOutput().get(0); + SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child().child(0)).left().getOutput().get(0); + SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child().child(0)).right().getOutput().get(0); Assertions.assertEquals(groupByKey.getExprId(), t1id.getExprId()); Assertions.assertNotEquals(t1id.getExprId(), t2id.getExprId()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java similarity index 98% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java index 3fca54eed96a29..c35b983911c859 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java +++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.catalog.AggregateType; import org.apache.doris.catalog.Column; @@ -23,7 +23,6 @@ import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.PartitionInfo; import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Slot; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index 03cc549bc2c1fa..8cacb4609186c3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -35,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.FieldChecker; import org.apache.doris.nereids.util.MemoPatternMatchSupported; @@ -45,8 +46,6 @@ import com.google.common.collect.Lists; import org.junit.jupiter.api.Test; -import java.util.stream.Collectors; - public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements MemoPatternMatchSupported { @Override @@ -86,35 +85,35 @@ public void testHavingGroupBySlot() { ); PlanChecker.from(connectContext).analyze(sql) .matches( 
- logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))); + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))))); sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0"; - a1 = new SlotReference( - new ExprId(1), "a1", TinyIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") - ); - Alias value = new Alias(new ExprId(3), a1, "value"); + SlotReference value = new SlotReference(new ExprId(3), "value", TinyIntType.INSTANCE, true, + ImmutableList.of()); PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))))); sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING value > 0"; PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))) ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))); sql = 
"SELECT SUM(a2) FROM t1 GROUP BY a1 HAVING a1 > 0"; @@ -130,13 +129,14 @@ public void testHavingGroupBySlot() { PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(sumA2, a1))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot())))); } @Test @@ -153,24 +153,28 @@ public void testHavingAggregateFunction() { Alias sumA2 = new Alias(new ExprId(3), new Sum(a2), "sum(a2)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING SUM(a2) > 0"; sumA2 = new Alias(new ExprId(3), new Sum(a2), "SUM(a2)"); 
PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan() + ) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING SUM(a2) > 0"; a1 = new SlotReference( @@ -184,20 +188,24 @@ public void testHavingAggregateFunction() { Alias value = new Alias(new ExprId(3), new Sum(a2), "value"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0"; PlanChecker.from(connectContext).analyze(sql) .matches( logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))) ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), 
Literal.of(0L)))))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0"; @@ -217,49 +225,53 @@ public void testHavingAggregateFunction() { Alias minPK = new Alias(new ExprId(4), new Min(pk), "min(pk)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2) > 0"; Alias sumA1A2 = new Alias(new ExprId(3), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L))))))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0"; Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), 
new TinyIntLiteral((byte) 3))), "sum(((a1 + a2) + 3))"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); sql = "SELECT a1 FROM t1 GROUP BY a1 HAVING COUNT(*) > 0"; Alias countStar = new Alias(new ExprId(3), new Count(), "count(*)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); } @Test @@ -281,19 +293,21 @@ void testJoinWithHaving() { Alias sumB1 = new Alias(new ExprId(7), new Sum(b1), "sum(b1)"); PlanChecker.from(connectContext).analyze(sql) .matches( - 
logicalProject( - logicalFilter( - logicalAggregate( + logicalProject( logicalFilter( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) - ) - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE), - sumB1.toSlot())))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); + logicalProject( + logicalAggregate( + logicalProject( + logicalFilter( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + )) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1))) + )).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE), + sumB1.toSlot())))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); } @Test @@ -331,6 +345,10 @@ void testComplexQueryWithHaving() { new ExprId(0), "pk", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); + SlotReference pk1 = new SlotReference( + new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true, + ImmutableList.of() + ); SlotReference a1 = new SlotReference( new ExprId(1), "a1", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") @@ -339,40 +357,42 @@ void testComplexQueryWithHaving() { new ExprId(2), "a2", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); - Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)"); Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)"); - Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), 
Literal.of((byte) 1)), "(COUNT(a1) + 1)"); + Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)"); + Alias countA11 = new Alias(new ExprId(10), new Add(countA1.toSlot(), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( + logicalProject( logicalFilter( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) + logicalProject( + logicalAggregate( + logicalProject( + logicalFilter( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + )) + ).when(FieldChecker.check("outputExpressions", + Lists.newArrayList(pk, pk1, sumA1, countA1, sumA1A2, v1)))) + ).when(FieldChecker.check("conjuncts", + ImmutableSet.of( + new GreaterThan(pk.toSlot(), Literal.of((byte) 0)), + new GreaterThan(countA11.toSlot(), Literal.of(0L)), + new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), + new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), + new GreaterThan(v1.toSlot(), Literal.of(0L)) + )) ) - ).when(FieldChecker.check("outputExpressions", - Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk))) - ).when(FieldChecker.check("conjuncts", - ImmutableSet.of( - new GreaterThan(pk.toSlot(), Literal.of((byte) 0)), - new GreaterThan(countA11.toSlot(), Literal.of(0L)), - new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), - new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), - new GreaterThan(v1.toSlot(), Literal.of(0L)) - )) - ) - ).when(FieldChecker.check( - "projects", Lists.newArrayList( - pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream() - .map(Alias::toSlot).collect(Collectors.toList())) - )); + ).when(FieldChecker.check( + "projects", Lists.newArrayList( + pk1, pk11.toSlot(), pk2.toSlot(), 
sumA1.toSlot(), countA11.toSlot(), sumA1A2.toSlot(), v1.toSlot()) + ) + )); } @Test @@ -391,9 +411,10 @@ public void testSortAggregateFunction() { .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); @@ -402,9 +423,10 @@ public void testSortAggregateFunction() { PlanChecker.from(connectContext).analyze(sql) .matches( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 ORDER BY SUM(a2)"; @@ -420,9 +442,10 @@ public void testSortAggregateFunction() { PlanChecker.from(connectContext).analyze(sql) .matches( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 ORDER BY MIN(pk)"; @@ -444,9 +467,10 @@ public void testSortAggregateFunction() { .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", 
Lists.newArrayList(a1, sumA2, minPK))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(minPK.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); @@ -455,9 +479,10 @@ public void testSortAggregateFunction() { PlanChecker.from(connectContext).analyze(sql) .matches( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 ORDER BY SUM(a1 + a2 + 3)"; @@ -467,9 +492,10 @@ public void testSortAggregateFunction() { .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A23.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); @@ -479,9 +505,10 @@ public void testSortAggregateFunction() { .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))) 
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(countStar.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); } @@ -495,6 +522,10 @@ void testComplexQueryWithOrderBy() { new ExprId(0), "pk", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); + SlotReference pk1 = new SlotReference( + new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true, + ImmutableList.of() + ); SlotReference a1 = new SlotReference( new ExprId(1), "a1", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") @@ -503,40 +534,41 @@ void testComplexQueryWithOrderBy() { new ExprId(2), "a2", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); - Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)"); Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)"); + Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)"); Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1"); PlanChecker.from(connectContext).analyze(sql) - .matches( - logicalProject( - logicalSort( - logicalAggregate( - logicalFilter( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) - ) - ).when(FieldChecker.check("outputExpressions", - Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk))) - ).when(FieldChecker.check("orderKeys", - ImmutableList.of( - new OrderKey(pk, true, true), - new OrderKey(countA11.toSlot(), true, true), - new OrderKey(new 
Add(sumA1A2.toSlot(), new TinyIntLiteral((byte) 1)), true, true), - new OrderKey(new Add(v1.toSlot(), new TinyIntLiteral((byte) 1)), true, true), - new OrderKey(v1.toSlot(), true, true) - ) - )) - ).when(FieldChecker.check( - "projects", Lists.newArrayList( - pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream() - .map(Alias::toSlot).collect(Collectors.toList())) - )); + .matches(logicalProject(logicalSort(logicalProject(logicalAggregate(logicalProject( + logicalFilter(logicalJoin(logicalOlapScan(), logicalOlapScan())))).when( + FieldChecker.check("outputExpressions", Lists.newArrayList(pk, pk1, + sumA1, countA1, sumA1A2, v1))))).when(FieldChecker.check( + "orderKeys", + ImmutableList.of(new OrderKey(pk, true, true), + new OrderKey( + countA11.toSlot(), true, true), + new OrderKey( + new Add(sumA1A2.toSlot(), + new TinyIntLiteral( + (byte) 1)), + true, true), + new OrderKey( + new Add(v1.toSlot(), + new TinyIntLiteral( + (byte) 1)), + true, true), + new OrderKey(v1.toSlot(), true, true))))) + .when(FieldChecker.check("projects", + Lists.newArrayList(pk1, + pk11.toSlot(), + pk2.toSlot(), + sumA1.toSlot(), + countA11.toSlot(), + sumA1A2.toSlot(), + v1.toSlot())))); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java similarity index 99% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java index 32f7b324f9af47..3808fd1842810f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java index 6f3bfaa7e53314..34c16309181466 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.implementation.AggregateStrategies; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.AggregateExpression; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java index 04e84ab8e89d86..5c43d7274d3fcc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java @@ -299,15 +299,16 @@ public void pruneAggregateOutput() { .matches( logicalProject( logicalSubQueryAlias( - logicalAggregate( logicalProject( - logicalOlapScan() - ).when(p -> getOutputQualifiedNames(p).equals( - ImmutableList.of("default_cluster:test.student.id") - )) - ).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals( - ImmutableList.of("default_cluster:test.student.id") - )) + logicalAggregate( + logicalProject( + logicalOlapScan() + ).when(p -> 
getOutputQualifiedNames(p).equals( + ImmutableList.of("default_cluster:test.student.id") + )) + ).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals( + ImmutableList.of("default_cluster:test.student.id") + ))) ) ) ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java index e676caa37a8c98..476131e6b068b1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java index 5667f3f2c5ad3d..5a98b07bcf041c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java @@ -18,6 +18,10 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.JoinType; import 
org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -30,6 +34,8 @@ import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; +import java.util.List; + class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @@ -99,4 +105,21 @@ void testJustRightSide() { && project.getProjects().get(1).toSql().equals("2name")) ); } + + @Test + void testNoPushdownMarkJoin() { + List projects = + ImmutableList.of(new MarkJoinSlotReference(new ExprId(101), "markSlot1", false), + new Alias(new MarkJoinSlotReference(new ExprId(102), "markSlot2", false), + "markSlot2")); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)).projectExprs(projects).build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownAliasThroughJoin()) + .matches(logicalProject(logicalJoin(logicalOlapScan(), logicalOlapScan())) + .when(project -> project.getProjects().get(0).toSql().equals("markSlot1") + && project.getProjects().get(1).toSql() + .equals("markSlot2 AS `markSlot2`"))); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java index 29cc509d954988..dfad75d5d8042a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java @@ -135,20 +135,22 @@ public void testAggNodeCase() { .applyTopDown(new FindHashConditionForJoin()) .applyTopDown(new 
PushdownExpressionsInHashCondition()) .matches( - logicalProject( - logicalJoin( - logicalProject( - logicalOlapScan() - ), - logicalProject( - logicalSubQueryAlias( - logicalAggregate( - logicalOlapScan() - ) + logicalProject( + logicalJoin( + logicalProject( + logicalOlapScan() + ), + logicalProject( + logicalSubQueryAlias( + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan() + ))) + ) + ) ) - ) ) - ) ); } @@ -168,8 +170,12 @@ public void testSortNodeCase() { logicalProject( logicalSubQueryAlias( logicalSort( - logicalAggregate( - logicalOlapScan() + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan() + ) + ) ) ) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java index a3bd46eb4f2f4e..ed5a96933db105 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java @@ -20,6 +20,7 @@ import org.apache.doris.common.FeConstants; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; import org.apache.doris.nereids.rules.rewrite.MergeProjects; +import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject; import org.apache.doris.nereids.trees.plans.PreAggStatus; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; @@ -188,7 +189,8 @@ public void testWithFilterAndProject() { PlanChecker.from(connectContext) .analyze(sql) .applyBottomUp(new LogicalSubQueryAliasToLogicalProject()) - .applyTopDown(new MergeProjects()) + .applyTopDown(new PushdownFilterThroughProject()) + .applyBottomUp(new MergeProjects()) .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) 
.matches(logicalOlapScan().when(scan -> { diff --git a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out new file mode 100644 index 00000000000000..5b97935639059d --- /dev/null +++ b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out @@ -0,0 +1,50 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql1 -- +3 + +-- !sql2 -- +3 + +-- !sql3 -- +3 + +-- !sql4 -- +false + +-- !sql5 -- +false + +-- !sql6 -- +true + +-- !sql7 -- +2 + +-- !sql8 -- +4 +4 + +-- !sql9 -- +4 +4 + +-- !sql10 -- +false +true + +-- !sql11 -- +false +true + +-- !sql12 -- +true +true + +-- !sql13 -- +2 +2 + +-- !sql14 -- +\N 2.0 +2020-09-09 2.0 + diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out index 8c934fb1876e22..0aa36ae310959b 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out @@ -23,7 +23,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------PhysicalOlapScan[customer] --------------PhysicalDistribute -----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE)) ------------------PhysicalProject --------------------hashJoin[INNER_JOIN](store.s_store_sk = ctr1.ctr_store_sk) ----------------------PhysicalDistribute @@ -32,11 +32,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------------PhysicalProject --------------------------filter((cast(s_state as VARCHAR(*)) = 'SD')) 
----------------------------PhysicalOlapScan[store] -------------------PhysicalProject ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalDistribute -----------------------------PhysicalProject -------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalDistribute +--------------------------PhysicalProject +----------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out index df28c5bee47b66..83982f37827139 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out @@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------PhysicalDistribute --------PhysicalTopN ----------PhysicalProject -------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE)) --------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk) ----------------PhysicalDistribute ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) @@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------------------filter((cast(ca_state as VARCHAR(*)) = 'IN')) ----------------------------PhysicalOlapScan[customer_address] --------------PhysicalDistribute -----------------PhysicalProject -------------------hashAgg[GLOBAL] ---------------------PhysicalDistribute -----------------------hashAgg[LOCAL] 
-------------------------PhysicalDistribute ---------------------------PhysicalProject -----------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +----------------hashAgg[GLOBAL] +------------------PhysicalDistribute +--------------------hashAgg[LOCAL] +----------------------PhysicalDistribute +------------------------PhysicalProject +--------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out index 8ba49dc8d60bdd..b8d6435601a5d9 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out @@ -14,30 +14,32 @@ PhysicalResultSink ----------------------PhysicalWindow ------------------------PhysicalQuickSort --------------------------PhysicalDistribute -----------------------------hashAgg[GLOBAL] -------------------------------PhysicalDistribute ---------------------------------hashAgg[LOCAL] -----------------------------------PhysicalProject -------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk) ---------------------------------------PhysicalProject -----------------------------------------PhysicalOlapScan[store_sales] ---------------------------------------PhysicalDistribute +----------------------------PhysicalProject +------------------------------hashAgg[GLOBAL] +--------------------------------PhysicalDistribute +----------------------------------hashAgg[LOCAL] +------------------------------------PhysicalProject +--------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk) ----------------------------------------PhysicalProject -------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216)) ---------------------------------------------PhysicalOlapScan[date_dim] 
+------------------------------------------PhysicalOlapScan[store_sales] +----------------------------------------PhysicalDistribute +------------------------------------------PhysicalProject +--------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216)) +----------------------------------------------PhysicalOlapScan[date_dim] --------------------PhysicalProject ----------------------PhysicalWindow ------------------------PhysicalQuickSort --------------------------PhysicalDistribute -----------------------------hashAgg[GLOBAL] -------------------------------PhysicalDistribute ---------------------------------hashAgg[LOCAL] -----------------------------------PhysicalProject -------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk) ---------------------------------------PhysicalProject -----------------------------------------PhysicalOlapScan[web_sales] ---------------------------------------PhysicalDistribute +----------------------------PhysicalProject +------------------------------hashAgg[GLOBAL] +--------------------------------PhysicalDistribute +----------------------------------hashAgg[LOCAL] +------------------------------------PhysicalProject +--------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk) ----------------------------------------PhysicalProject -------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227)) ---------------------------------------------PhysicalOlapScan[date_dim] +------------------------------------------PhysicalOlapScan[web_sales] +----------------------------------------PhysicalDistribute +------------------------------------------PhysicalProject +--------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227)) +----------------------------------------------PhysicalOlapScan[date_dim] diff 
--git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out index 77c7b273ba3d0f..bfcec6ce4127f4 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out @@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------PhysicalDistribute --------PhysicalTopN ----------PhysicalProject -------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE)) --------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk) ----------------PhysicalDistribute ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) @@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------------------filter((cast(ca_state as VARCHAR(*)) = 'CA')) ----------------------------PhysicalOlapScan[customer_address] --------------PhysicalDistribute -----------------PhysicalProject -------------------hashAgg[GLOBAL] ---------------------PhysicalDistribute -----------------------hashAgg[LOCAL] -------------------------PhysicalDistribute ---------------------------PhysicalProject -----------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +----------------hashAgg[GLOBAL] +------------------PhysicalDistribute +--------------------hashAgg[LOCAL] +----------------------PhysicalDistribute +------------------------PhysicalProject +--------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out index 6114877bc9dd66..9913e27f5fb37b 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out 
+++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out @@ -9,13 +9,12 @@ PhysicalResultSink ------------PhysicalDistribute --------------PhysicalProject ----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) -------------------PhysicalProject ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalProject -----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) -------------------------------PhysicalOlapScan[lineitem] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalProject +--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) +----------------------------PhysicalOlapScan[lineitem] ------------------PhysicalDistribute --------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey) ----------------------PhysicalProject diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out index 6114877bc9dd66..9913e27f5fb37b 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out @@ -9,13 +9,12 @@ PhysicalResultSink ------------PhysicalDistribute --------------PhysicalProject ----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) -------------------PhysicalProject ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalProject 
-----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) -------------------------------PhysicalOlapScan[lineitem] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalProject +--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) +----------------------------PhysicalOlapScan[lineitem] ------------------PhysicalDistribute --------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey) ----------------------PhysicalProject diff --git a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy new file mode 100644 index 00000000000000..0521334d8ae881 --- /dev/null +++ b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+suite("test_subquery_in_project") {
+    // Subqueries in the SELECT list (scalar, IN, EXISTS) on the Nereids planner, fallback disabled.
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+    sql """drop table if exists test_sql;"""
+    sql """
+        CREATE TABLE `test_sql` (
+        `user_id` varchar(10) NULL,
+        `dt` date NULL,
+        `city` varchar(20) NULL,
+        `age` int(11) NULL
+        ) ENGINE=OLAP
+        UNIQUE KEY(`user_id`)
+        COMMENT 'test'
+        DISTRIBUTED BY HASH(`user_id`) BUCKETS 1
+        PROPERTIES (
+        "replication_allocation" = "tag.location.default: 1",
+        "is_being_synced" = "false",
+        "storage_format" = "V2",
+        "light_schema_change" = "true",
+        "disable_auto_compaction" = "false",
+        "enable_single_replica_compaction" = "false"
+        );
+    """
+
+    // Phase 1: test_sql holds a single row.
+    sql """ insert into test_sql values (1,'2020-09-09',2,3);"""
+    qt_sql1 """
+        select (select age from test_sql) col from test_sql order by col;
+    """
+
+    qt_sql2 """
+        select (select sum(age) from test_sql) col from test_sql order by col;
+    """
+
+    qt_sql3 """
+        select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col;
+    """
+
+    qt_sql4 """
+        select age in (select user_id from test_sql) col from test_sql order by col;
+    """
+
+    qt_sql5 """
+        select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col;
+    """
+
+    qt_sql6 """
+        select exists ( select user_id from test_sql ) col from test_sql order by col;
+    """
+
+    qt_sql7 """
+        select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col;
+    """
+
+    // Phase 2: with two rows the bare scalar subquery must be rejected (test{} fails if nothing is thrown).
+    sql """ insert into test_sql values (2,'2020-09-09',2,1);"""
+    test {
+        sql """
+            select (select age from test_sql) col from test_sql order by col;
+        """
+        exception "Expected EQ 1 to be returned by expression"
+    }
+
+    qt_sql8 """
+        select (select sum(age) from test_sql) col from test_sql order by col;
+    """
+
+    qt_sql9 """
+        select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col;
+    """
+
+    qt_sql10 """
+        select age in (select user_id from test_sql) col from test_sql order by col;
+    """
+
+    qt_sql11 """
+        select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col;
+    """
+
+    qt_sql12 """
+        select exists ( select user_id from test_sql ) col from test_sql order by col;
+    """
+
+    qt_sql13 """
+        select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col;
+    """
+
+    qt_sql14 """
+        select dt,case when 'med'='med' then (
+            select sum(midean) from (
+                select sum(score) / count(*) as midean
+                from (
+                    select age score,row_number() over (order by age desc) as desc_math,
+                    row_number() over (order by age asc) as asc_math from test_sql
+                ) as order_table
+                where asc_math in (desc_math, desc_math + 1, desc_math - 1)) m
+            )
+            end 'test' from test_sql group by cube(dt) order by dt;
+    """
+
+    sql """drop table if exists test_sql;"""
+}