Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,26 @@ public List<Rule> buildRules() {
LogicalAggregate<GroupPlan> agg = ctx.root;
List<NamedExpression> output =
bind(agg.getOutputExpressions(), agg.children(), agg, ctx.cascadesContext);

// The columns referenced in group by are first obtained from the child's output,
// and then from the node's output
Map<String, Expression> childOutputsToExpr = agg.child().getOutput().stream()
.collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr));
Map<String, Expression> aliasNameToExpr = output.stream()
.filter(ne -> ne instanceof Alias)
.map(Alias.class::cast)
.collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr));
aliasNameToExpr.entrySet().stream()
.forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue()));

List<Expression> replacedGroupBy = agg.getGroupByExpressions().stream()
.map(groupBy -> {
if (groupBy instanceof UnboundSlot) {
UnboundSlot unboundSlot = (UnboundSlot) groupBy;
if (unboundSlot.getNameParts().size() == 1) {
String name = unboundSlot.getNameParts().get(0);
if (aliasNameToExpr.containsKey(name)) {
return aliasNameToExpr.get(name);
if (childOutputsToExpr.containsKey(name)) {
return childOutputsToExpr.get(name);
}
}
}
Expand All @@ -197,19 +205,26 @@ public List<Rule> buildRules() {
List<NamedExpression> output =
bind(repeat.getOutputExpressions(), repeat.children(), repeat, ctx.cascadesContext);

// The columns referenced in group by are first obtained from the child's output,
// and then from the node's output
Map<String, Expression> childOutputsToExpr = repeat.child().getOutput().stream()
.collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr));
Map<String, Expression> aliasNameToExpr = output.stream()
.filter(ne -> ne instanceof Alias)
.map(Alias.class::cast)
.collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr));
aliasNameToExpr.entrySet().stream()
.forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue()));

List<List<Expression>> replacedGroupingSets = repeat.getGroupingSets().stream()
.map(groupBy ->
groupBy.stream().map(expr -> {
if (expr instanceof UnboundSlot) {
UnboundSlot unboundSlot = (UnboundSlot) expr;
if (unboundSlot.getNameParts().size() == 1) {
String name = unboundSlot.getNameParts().get(0);
if (aliasNameToExpr.containsKey(name)) {
return aliasNameToExpr.get(name);
if (childOutputsToExpr.containsKey(name)) {
return childOutputsToExpr.get(name);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ private Set<Expression> collectNeedToSlotExpressions(LogicalRepeat<Plan> repeat)
}

private Plan pushDownProject(Set<NamedExpression> pushedExprs, Plan originBottomPlan) {
if (!pushedExprs.equals(originBottomPlan.getOutputSet())) {
if (!pushedExprs.equals(originBottomPlan.getOutputSet()) && !pushedExprs.isEmpty()) {
return new LogicalProject<>(ImmutableList.copyOf(pushedExprs), originBottomPlan);
}
return originBottomPlan;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ public Rule build() {
aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions);

List<Slot> normalizedGroupBy =
(List) groupByAndArgumentToSlotContext.normalizeToUseSlotRef(aggregate.getGroupByExpressions());
(List) groupByAndArgumentToSlotContext
.normalizeToUseSlotRef(aggregate.getGroupByExpressions());

// we can safely add all groupBy and aggregate functions to output, because we will
// add a project on it, and the upper project can protect the scope of visible of slot
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,53 @@ public void test() {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(k1 = 1, 2, k1) k_if from t1");
}

@Test
public void test1() {
PlanChecker.from(connectContext)
.checkPlannerResult("select coalesce(col1, 'all') as col1, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}

@Test
public void test1_1() {
PlanChecker.from(connectContext)
.checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}

@Test
public void test1_2() {
PlanChecker.from(connectContext)
.checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());");
}

@Test
public void test2() {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(1 = null, 'all', 2) as col1, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}

@Test
public void test2_1() {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(col1 = null, 'all', 2) as col1, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}

@Test
public void test2_2() {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}

@Test
public void test2_3() {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());");
}
}
34 changes: 34 additions & 0 deletions regression-test/data/nereids_syntax_p0/grouping_sets.out
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,37 @@
3
4

-- !select1 --
a 1
all 1
all 2

-- !select2 --
a 1
all 1
all 2

-- !select3 --
\N 2
a 1
all 1

-- !select4 --
2 1
2 1
2 2

-- !select5 --
2 1
2 1
2 2

-- !select6 --
2 1
2 1
2 2

-- !select7 --
2 1
2 1
2 2
28 changes: 28 additions & 0 deletions regression-test/suites/nereids_syntax_p0/grouping_sets.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -232,4 +232,32 @@ suite("test_nereids_grouping_sets") {
) T
) T2;
"""

order_qt_select1 """
select coalesce(col1, 'all') as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""

order_qt_select2 """
select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""

order_qt_select3 """
select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());
"""

order_qt_select4 """
select if(1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""

order_qt_select5 """
select if(col1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""

order_qt_select6 """
select if(1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""

order_qt_select7 """
select if(col1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""
}