Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -101,23 +101,94 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
@Override
public Rule build() {
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
// The LogicalAggregate node may contain window agg functions and usual agg functions
// we call window agg functions as window-agg and usual agg functions as trival-agg for short
// This rule simplify LogicalAggregate node by:
// 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
// 2. create a new LogicalAggregate with normalized group by exprs and trival-aggs
// 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
// Push down exprs:
// 1. all group by exprs
// 2. child contains subquery expr in trival-agg
// 3. child contains window expr in trival-agg
// 4. all input slots of trival-agg
// 5. expr(including subquery) in distinct trival-agg
// Normalize LogicalAggregate's output.
// 1. normalize group by exprs by outputs of bottom LogicalProject
// 2. normalize trival-aggs by outputs of bottom LogicalProject
// 3. build normalized agg outputs
// Pull up exprs:
// normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
// 1. simple slots
// 2. aliases
// a. alias with no aggs child
// b. alias with trival-agg child
// c. alias with window-agg

List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
// Push down exprs:
// collect group by exprs
Set<Expression> groupingByExprs =
ImmutableSet.copyOf(aggregate.getGroupByExpressions());

// collect all trival-agg
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));

// we need push down subquery exprs inside non-window and non-distinct agg functions
// because the distinct agg's children would be pushed down in later process
Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct()).collect(Collectors.toList()),
SubqueryExpr.class::isInstance);
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
// split non-distinct agg child as two part
// TRUE part 1: need push down itself, if it contains subqury or window expression
// FALSE part 2: need push down its input slots, if it DOES NOT contain subqury or window expression
Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct())
.flatMap(agg -> agg.children().stream())
.collect(Collectors.groupingBy(
child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
Collectors.toSet()));

// split distinct agg child as two parts
// TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
// FALSE part 2: need push down its input slots, if it is SlotReference or Literal
Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
.filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
.collect(Collectors.groupingBy(
child -> !(child instanceof SlotReference || child instanceof Literal),
Collectors.toSet()));

Set<Expression> needPushSelf = Sets.union(
categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));

Set<Alias> existsAlias =
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);

// push down 3 kinds of exprs, these pushed exprs will be used to normalize agg output later
// 1. group by exprs
// 2. trivalAgg children
// 3. trivalAgg input slots
Set<Expression> allPushDownExprs =
Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
NormalizeToSlotContext bottomSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs));
Set<NamedExpression> bottomOutputs =
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));
NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
Set<NamedExpression> pushedGroupByExprs =
bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
Set<NamedExpression> pushedTrivalAggChildren =
bottomSlotContext.pushDownToNamedExpression(needPushSelf);
Set<NamedExpression> pushedTrivalAggInputSlots =
bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));

// create bottom project
Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
aggregate.child());
} else {
bottomPlan = aggregate.child();
}

// use group by context to normalize agg functions to process
// sql like: select sum(a + 1) from t group by a + 1
Expand All @@ -129,89 +200,37 @@ public Rule build() {
// after normalize:
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
// +-- project((a[#0] + 1)[#1])
List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);
// TODO: if we have distinct agg, we must push down its children,
// because need use it to generate distribution enforce
// step 1: split agg functions into 2 parts: distinct and not distinct
List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
for (AggregateFunction aggregateFunction : normalizedAggFuncs) {
if (aggregateFunction.isDistinct()) {
distinctAggFuncs.add(aggregateFunction);
} else {
nonDistinctAggFuncs.add(aggregateFunction);
}
}
// step 2: if we only have one distinct agg function, we do push down for it
if (!distinctAggFuncs.isEmpty()) {
// process distinct normalize and put it back to normalizedAggFuncs
List<AggregateFunction> newDistinctAggFuncs = Lists.newArrayList();
Map<Expression, Expression> replaceMap = Maps.newHashMap();
Map<Expression, NamedExpression> aliasCache = Maps.newHashMap();
for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
List<Expression> newChildren = Lists.newArrayList();
for (Expression child : distinctAggFunc.children()) {
if (child instanceof SlotReference || child instanceof Literal) {
newChildren.add(child);
} else {
NamedExpression alias;
if (aliasCache.containsKey(child)) {
alias = aliasCache.get(child);
} else {
alias = new Alias(child);
aliasCache.put(child, alias);
}
bottomProjects.add(alias);
newChildren.add(alias.toSlot());
}
}
AggregateFunction newDistinctAggFunc = distinctAggFunc.withChildren(newChildren);
replaceMap.put(distinctAggFunc, newDistinctAggFunc);
newDistinctAggFuncs.add(newDistinctAggFunc);
}
aggregateOutput = aggregateOutput.stream()
.map(e -> ExpressionUtils.replace(e, replaceMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
distinctAggFuncs = newDistinctAggFuncs;
}
normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
normalizedAggFuncs.addAll(distinctAggFuncs);
// TODO: process redundant expressions in aggregate functions children

// normalize group by exprs by bottomProjects
List<Expression> normalizedGroupExprs =
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);

// normalize trival-aggs by bottomProjects
List<AggregateFunction> normalizedAggFuncs =
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);

// build normalized agg output
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
// agg output include 2 part, normalized group by slots and normalized agg functions

// agg output include 2 parts
// pushedGroupByExprs and normalized agg functions
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
.addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext
.pushDownToNamedExpression(normalizedAggFuncs))
.build();
// add normalized agg's input slots to bottom projects
Set<Slot> bottomProjectSlots = bottomProjects.stream()
.map(NamedExpression::toSlot)
.collect(Collectors.toSet());
Set<NamedExpression> aggInputSlots = normalizedAggFuncs.stream()
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(e -> !bottomProjectSlots.contains(e))
.collect(Collectors.toSet());
bottomProjects.addAll(aggInputSlots);
// build group by exprs
List<Expression> normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);

Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
} else {
bottomPlan = aggregate.child();
}
// create new agg node
LogicalAggregate newAggregate =
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);

// create upper projects by normalize all output exprs in old LogicalAggregate
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
bottomSlotContext, normalizedAggFuncsToSlotContext);

return new LogicalProject<>(upperProjects,
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
// create a parent project node
return new LogicalProject<>(upperProjects, newAggregate);
}).toRule(RuleType.NORMALIZE_AGGREGATE);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

suite("agg_distinct_case_when") {
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
sql "DROP TABLE IF EXISTS agg_test_table_t;"
sql """
CREATE TABLE `agg_test_table_t` (
`k1` varchar(65533) NULL,
`k2` text NULL,
`k3` text null,
`k4` text null
) ENGINE=OLAP
DUPLICATE KEY(`k1`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`k1`) BUCKETS 10
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"is_being_synced" = "false",
"storage_format" = "V2",
"light_schema_change" = "true",
"disable_auto_compaction" = "false",
"enable_single_replica_compaction" = "false"
);
"""

sql """insert into agg_test_table_t(`k1`,`k2`,`k3`) values('20231026221524','PA','adigu1bububud');"""
sql """
select
count(distinct case when t.k2='PA' and loan_date=to_date(substr(t.k1,1,8)) then t.k2 end )
from (
select substr(k1,1,8) loan_date,k3,k2,k1 from agg_test_table_t) t
group by
substr(t.k1,1,8);"""

sql "DROP TABLE IF EXISTS agg_test_table_t;"
}
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,8 @@ suite("test_window_fn") {
"storage_format" = "V2"
);
"""
test {
sql """SELECT SUM(MAX(c1) OVER (PARTITION BY c2, c3)) FROM test_window_in_agg;"""
exception "errCode = 2, detailMessage = AGGREGATE clause must not contain analytic expressions"
}
sql """set enable_nereids_planner=true;"""
sql """SELECT SUM(MAX(c1) OVER (PARTITION BY c2, c3)) FROM test_window_in_agg;"""
sql "DROP TABLE IF EXISTS test_window_in_agg;"
}

Expand Down