Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.doris.nereids.rules.analysis.CollectJoinConstraint;
import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
import org.apache.doris.nereids.rules.analysis.CompressedMaterialize;
import org.apache.doris.nereids.rules.analysis.EliminateDistinctConstant;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
Expand All @@ -43,7 +42,6 @@
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
import org.apache.doris.nereids.rules.analysis.OneRowRelationExtractAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
import org.apache.doris.nereids.rules.analysis.QualifyToFilter;
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.SubqueryToApply;
Expand Down Expand Up @@ -110,13 +108,6 @@ private static List<RewriteJob> buildAnalyzerJobs() {
topDown(new FillUpQualifyMissingSlot()),
bottomUp(
new ProjectToGlobalAggregate(),
// this rule check's the logicalProject node's isDistinct property
// and replace the logicalProject node with a LogicalAggregate node
// so any rule before this, if create a new logicalProject node
// should make sure isDistinct property is correctly passed around.
// please see rule BindSlotReference or BindFunction for example
new EliminateDistinctConstant(),
new ProjectWithDistinctToAggregate(),
new ReplaceExpressionByChildOutput(),
new OneRowRelationExtractAggregate()
),
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,24 @@

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors;
import org.apache.doris.nereids.trees.plans.LimitPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ProjectToGlobalAggregate.
* <p>
Expand All @@ -43,17 +54,110 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return RuleType.PROJECT_TO_GLOBAL_AGGREGATE.build(
logicalProject().then(project -> {
boolean needGlobalAggregate = project.getProjects()
.stream()
.anyMatch(p -> p.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null));

if (needGlobalAggregate) {
return new LogicalAggregate<>(ImmutableList.of(), project.getProjects(), project.child());
} else {
return project;
}
})
logicalProject().then(project -> {
project = distinctConstantsToLimit1(project);
Plan result = projectToAggregate(project);
return distinctToAggregate(result, project);
})
);
}

// select distinct 1,2,3 from tbl
// ↓
// select 1,2,3 from (select 1, 2, 3 from tbl limit 1) as tmp
private static LogicalProject<Plan> distinctConstantsToLimit1(LogicalProject<Plan> project) {
if (!project.isDistinct()) {
return project;
}

boolean allSelectItemAreConstants = true;
for (NamedExpression selectItem : project.getProjects()) {
if (!selectItem.isConstant()) {
allSelectItemAreConstants = false;
break;
}
}

if (allSelectItemAreConstants) {
return new LogicalProject<>(
project.getProjects(),
new LogicalLimit<>(1, 0, LimitPhase.ORIGIN, project.child())
);
}
return project;
}

// select avg(xxx) from tbl
// ↓
// LogicalAggregate(groupBy=[], output=[avg(xxx)])
private static Plan projectToAggregate(LogicalProject<Plan> project) {
// contains aggregate functions, like sum, avg ?
for (NamedExpression selectItem : project.getProjects()) {
if (selectItem.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null)) {
return new LogicalAggregate<>(ImmutableList.of(), project.getProjects(), project.child());
}
}
return project;
}

private static Plan distinctToAggregate(Plan result, LogicalProject<Plan> originProject) {
if (!originProject.isDistinct()) {
return result;
}
if (result instanceof LogicalProject) {
// remove distinct: select distinct fun(xxx) as c1 from tbl
//
// LogicalProject(distinct=true, output=[fun(xxx) as c1])
// ↓
// LogicalAggregate(groupBy=[c1], output=[c1])
// |
// LogicalProject(output=[fun(xxx) as c1])
LogicalProject<?> project = (LogicalProject<?>) result;

ImmutableList.Builder<NamedExpression> bottomProjectOutput
= ImmutableList.builderWithExpectedSize(project.getProjects().size());
ImmutableList.Builder<NamedExpression> topAggOutput
= ImmutableList.builderWithExpectedSize(project.getProjects().size());

boolean hasComplexExpr = false;
for (NamedExpression selectItem : project.getProjects()) {
if (selectItem.isSlot()) {
topAggOutput.add(selectItem);
bottomProjectOutput.add(selectItem);
} else if (isAliasLiteral(selectItem)) {
// stay in agg, and eliminate by `ELIMINATE_GROUP_BY_CONSTANT`
topAggOutput.add(selectItem);
} else {
// `FillUpMissingSlots` not support find complex expr in aggregate,
// so we should push down into the bottom project
hasComplexExpr = true;
topAggOutput.add(selectItem.toSlot());
bottomProjectOutput.add(selectItem);
}
}

if (!hasComplexExpr) {
List<Slot> projects = (List) project.getProjects();
return new LogicalAggregate(projects, projects, project.child());
}

LogicalProject<?> removeDistinct = new LogicalProject<>(bottomProjectOutput.build(), project.child());
ImmutableList<NamedExpression> aggOutput = topAggOutput.build();
return new LogicalAggregate(aggOutput, aggOutput, removeDistinct);
} else if (result instanceof LogicalAggregate) {
// remove distinct: select distinct avg(xxx) as c1 from tbl
//
// LogicalProject(distinct=true, output=[avg(xxx) as c1])
// ↓
// LogicalAggregate(output=[avg(xxx) as c1])
return result;
} else {
// never reach
throw new AnalysisException("Unsupported");
}
}

private static boolean isAliasLiteral(NamedExpression selectItem) {
return selectItem instanceof Alias && selectItem.child(0) instanceof Literal;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,27 @@ public List<Rule> buildRules() {
))
.add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
logicalSort(logicalAggregate()).then(sort -> {
LogicalAggregate<Plan> aggregate = sort.child();
Map<Expression, Slot> sMap = buildOutputAliasMap(aggregate.getOutputExpressions());
LogicalAggregate<Plan> agg = sort.child();
Map<Expression, Slot> sMap = buildOutputAliasMap(agg.getOutputExpressions());
if (sMap.isEmpty() && isSelectDistinct(agg)) {
sMap = getSelectDistinctExpressions(agg);
}
return replaceSortExpression(sort, sMap);
})
)).add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
logicalSort(logicalHaving(logicalAggregate())).then(sort -> {
LogicalAggregate<Plan> aggregate = sort.child().child();
Map<Expression, Slot> sMap = buildOutputAliasMap(aggregate.getOutputExpressions());
LogicalAggregate<Plan> agg = sort.child().child();
Map<Expression, Slot> sMap = buildOutputAliasMap(agg.getOutputExpressions());
if (sMap.isEmpty() && isSelectDistinct(agg)) {
sMap = getSelectDistinctExpressions(agg);
}
return replaceSortExpression(sort, sMap);
})
))
.build();
}

private Map<Expression, Slot> buildOutputAliasMap(List<NamedExpression> output) {
private static Map<Expression, Slot> buildOutputAliasMap(List<NamedExpression> output) {
Map<Expression, Slot> sMap = Maps.newHashMapWithExpectedSize(output.size());
for (NamedExpression expr : output) {
if (expr instanceof Alias) {
Expand All @@ -93,4 +99,22 @@ private LogicalPlan replaceSortExpression(LogicalSort<? extends LogicalPlan> sor

return changed ? new LogicalSort<>(newKeys.build(), sort.child()) : sort;
}

private static boolean isSelectDistinct(LogicalAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().equals(agg.getOutputExpressions())
&& agg.getGroupByExpressions().equals(agg.child().getOutput());
}

private static Map<Expression, Slot> getSelectDistinctExpressions(LogicalAggregate<? extends Plan> agg) {
Plan child = agg.child();
List<NamedExpression> selectItems;
if (child instanceof LogicalProject) {
selectItems = ((LogicalProject<?>) child).getProjects();
} else if (child instanceof LogicalAggregate) {
selectItems = ((LogicalAggregate<?>) child).getOutputExpressions();
} else {
selectItems = ImmutableList.of();
}
return buildOutputAliasMap(selectItems);
}
}
48 changes: 48 additions & 0 deletions regression-test/suites/nereids_p0/aggregate/select_distinct.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

suite("select_distinct") {
multi_sql """
SET enable_nereids_planner=true;
SET enable_fallback_to_original_planner=false;
drop table if exists test_distinct_window;
create table test_distinct_window(id int) distributed by hash(id) properties('replication_num'='1');
insert into test_distinct_window values(1), (2), (3);
"""

test {
sql "select distinct sum(value) over(partition by id) from (select 100 value, 1 id union all select 100, 2)a"
result([[100L]])
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add more case to cover all if branch in rules, such as

  • qualify with alias
  • qualify with new window function
  • qualify with window function which has exsited in project list
  • having

test {
sql "select distinct value+1 from (select 100 value, 1 id union all select 100, 2)a order by value+1"
result([[101]])
}

test {
sql "select distinct 1, 2, 3 from test_distinct_window"
result([[1, 2, 3]])
}

test {
sql "select distinct sum(id) from test_distinct_window"
result([[6L]])
}
}
Loading