Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,9 @@ public class Rewriter extends AbstractBatchJobExecutor {

topic("Eliminate GroupBy",
topDown(new EliminateGroupBy(),
new MergeAggregate())
new MergeAggregate(),
// need to adjust min/max/sum nullable attribute after merge aggregate
new AdjustAggregateNullableForEmptySet())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment here to explain why need execute AdjustAggregateNullableForEmptySet after MergeAggregate

),

topic("Eager aggregation",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -87,15 +89,14 @@ private Plan mergeTwoAggregate(LogicalAggregate<LogicalAggregate<Plan>> outerAgg
private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
LogicalAggregate<Plan> innerAgg = project.child();

List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
project.getProjects(), (List) outputExpressions);
// rewrite agg function. e.g. max(max)
List<NamedExpression> aggFunc = outerAgg.getOutputExpressions().stream()
List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
.collect(Collectors.toList());
// rewrite agg function directly refer to the slot below the project
List<Expression> replacedAggFunc = PlanUtils.replaceExpressionByProjections(project.getProjects(),
(List) aggFunc);
// replace groupByKeys directly refer to the slot below the project
List<Expression> replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(),
outerAgg.getGroupByExpressions());
Expand Down Expand Up @@ -138,13 +139,17 @@ private NamedExpression rewriteAggregateFunction(NamedExpression e,
}

boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
boolean sameGroupBy) {
boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
for (AggregateFunction outerFunc : aggregateFunctions) {
List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
(List<AggregateFunction>) PlanUtils.replaceExpressionByProjections(
projectOptional.get().getProjects(), new ArrayList<>(aggregateFunctions)))
.orElse(new ArrayList<>(aggregateFunctions));
for (AggregateFunction outerFunc : replacedAggFunctions) {
if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) {
return false;
}
Expand Down Expand Up @@ -188,7 +193,7 @@ private boolean canMergeAggregateWithoutProject(LogicalAggregate<LogicalAggregat
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());

return commonCheck(outerAgg, innerAgg, sameGroupBy);
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.empty());
}

private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
Expand All @@ -206,6 +211,6 @@ private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<Log
return false;
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
return commonCheck(outerAgg, innerAgg, sameGroupBy);
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,54 @@ PhysicalResultSink
--------------------PhysicalProject
----------------------PhysicalOlapScan[mal_test1]

-- !test_has_project_distinct_cant_transform --
1

-- !test_has_project_distinct_cant_transform_shape --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalProject
----------hashAgg[GLOBAL]
------------PhysicalDistribute[DistributionSpecHash]
--------------hashAgg[LOCAL]
----------------PhysicalProject
------------------PhysicalOlapScan[mal_test_merge_agg]

-- !test_distinct_expr_transform --
-1

-- !test_distinct_expr_transform_shape --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalProject
----------PhysicalOlapScan[mal_test_merge_agg]

-- !test_has_project_distinct_expr_transform --
1
1
1

-- !test_has_project_distinct_expr_transform --
PhysicalResultSink
--PhysicalDistribute[DistributionSpecGather]
----PhysicalProject
------hashAgg[GLOBAL]
--------PhysicalDistribute[DistributionSpecHash]
----------hashAgg[LOCAL]
------------PhysicalProject
--------------PhysicalOlapScan[mal_test_merge_agg]

-- !test_sum_empty_table --
\N \N \N

-- !test_sum_empty_table_shape --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalOlapScan[mal_test2]

Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,84 @@ suite("merge_aggregate") {
group by a order by 1,2;
"""

sql "drop table if exists mal_test_merge_agg"
sql """
create table mal_test_merge_agg(
k1 int null,
k2 int not null,
k3 string null,
k4 varchar(100) null
)
duplicate key (k1,k2)
distributed BY hash(k1) buckets 3
properties("replication_num" = "1");
"""
sql "insert into mal_test_merge_agg select 1,1,'1','a';"
sql "insert into mal_test_merge_agg select 2,2,'2','b';"
sql "insert into mal_test_merge_agg select 3,-3,null,'c';"
sql "sync"

qt_test_has_project_distinct_cant_transform """
select max(count_col)
from (
select k4,
count(distinct case when k3 is null then 1 else 0 end) as count_col
from mal_test_merge_agg group by k4
) t ;
"""
qt_test_has_project_distinct_cant_transform_shape """
explain shape plan
select max(count_col)
from (
select k4,
count(distinct case when k3 is null then 1 else 0 end) as count_col
from mal_test_merge_agg group by k4
) t ;
"""

qt_test_distinct_expr_transform """
select max(count_col)
from (
select k4,
max(-abs(k1)) as count_col
from mal_test_merge_agg group by k4
) t ;
"""
qt_test_distinct_expr_transform_shape """
explain shape plan
select max(count_col)
from (
select k4,
max(-abs(k1)) as count_col
from mal_test_merge_agg group by k4
) t ;
"""

qt_test_has_project_distinct_expr_transform """
select sum(count_col)
from (
select k4,
count(distinct case when k3 is null then 1 else 0 end) as count_col
from mal_test_merge_agg group by k4
) t group by k4;
"""

qt_test_has_project_distinct_expr_transform """
explain shape plan
select sum(count_col)
from (
select k4,
count(distinct case when k3 is null then 1 else 0 end) as count_col
from mal_test_merge_agg group by k4
) t group by k4;
"""

qt_test_sum_empty_table """
select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t;
"""

qt_test_sum_empty_table_shape """
explain shape plan
select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t;
"""
}