Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ private Plan bindSortWithoutSetOperation(MatchingContext<LogicalSort<Plan>> ctx)
final Plan finalInput = input;
Supplier<Scope> inputChildrenScope = Suppliers.memoize(
() -> toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(finalInput.children())));
SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer(
SimpleExprAnalyzer bindInInputScopeThenInputChildScope = buildCustomSlotBinderAnalyzer(
sort, cascadesContext, inputScope, true, false,
(self, unboundSlot) -> {
// first, try to bind slot in Scope(input.output)
Expand All @@ -774,9 +774,17 @@ private Plan bindSortWithoutSetOperation(MatchingContext<LogicalSort<Plan>> ctx)
return self.bindExactSlotsByThisScope(unboundSlot, inputChildrenScope.get());
});

SimpleExprAnalyzer bindInInputChildScope = getAnalyzerForOrderByAggFunc(finalInput, cascadesContext, sort,
inputChildrenScope, inputScope);
Builder<OrderKey> boundOrderKeys = ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size());
FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry();
for (OrderKey orderKey : sort.getOrderKeys()) {
Expression boundKey = bindWithOrdinal(orderKey.getExpr(), analyzer, childOutput);
Expression boundKey;
if (hasAggregateFunction(orderKey.getExpr(), functionRegistry)) {
boundKey = bindInInputChildScope.analyze(orderKey.getExpr());
} else {
boundKey = bindWithOrdinal(orderKey.getExpr(), bindInInputScopeThenInputChildScope, childOutput);
}
boundOrderKeys.add(orderKey.withExpression(boundKey));
}
return new LogicalSort<>(boundOrderKeys.build(), sort.child());
Expand Down Expand Up @@ -965,4 +973,49 @@ default <E extends Expression> Set<E> analyzeToSet(List<E> exprs) {
private interface CustomSlotBinderAnalyzer {
List<? extends Expression> bindSlot(ExpressionAnalyzer analyzer, UnboundSlot unboundSlot);
}

private boolean hasAggregateFunction(Expression expression, FunctionRegistry functionRegistry) {
return expression.anyMatch(expr -> {
if (expr instanceof AggregateFunction) {
return true;
} else if (expr instanceof UnboundFunction) {
UnboundFunction unboundFunction = (UnboundFunction) expr;
boolean isAggregateFunction = functionRegistry
.isAggregateFunction(
unboundFunction.getDbName(),
unboundFunction.getName()
);
return isAggregateFunction;
}
return false;
});
}

private SimpleExprAnalyzer getAnalyzerForOrderByAggFunc(Plan finalInput, CascadesContext cascadesContext,
LogicalSort<Plan> sort, Supplier<Scope> inputChildrenScope, Scope inputScope) {
ImmutableList.Builder<Slot> outputSlots = ImmutableList.builder();
if (finalInput instanceof LogicalAggregate) {
LogicalAggregate<Plan> aggregate = (LogicalAggregate<Plan>) finalInput;
List<NamedExpression> outputExpressions = aggregate.getOutputExpressions();
for (NamedExpression outputExpr : outputExpressions) {
if (!outputExpr.anyMatch(expr -> expr instanceof AggregateFunction)) {
outputSlots.add(outputExpr.toSlot());
}
}
}
Scope outputWithoutAggFunc = toScope(cascadesContext, outputSlots.build());
SimpleExprAnalyzer bindInInputChildScope = buildCustomSlotBinderAnalyzer(
sort, cascadesContext, inputScope, true, false,
(analyzer, unboundSlot) -> {
if (finalInput instanceof LogicalAggregate) {
List<Slot> boundInOutputWithoutAggFunc = analyzer.bindSlotByScope(unboundSlot,
outputWithoutAggFunc);
if (!boundInOutputWithoutAggFunc.isEmpty()) {
return ImmutableList.of(boundInOutputWithoutAggFunc.get(0));
}
}
return analyzer.bindExactSlotsByThisScope(unboundSlot, inputChildrenScope.get());
});
return bindInInputChildScope;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ public void resolve(Expression expression) {
// We couldn't find the equivalent expression in output expressions and group-by expressions,
// so we should check whether the expression is valid.
if (expression instanceof SlotReference) {
throw new AnalysisException(expression.toSql() + " in having clause should be grouped by.");
throw new AnalysisException(expression.toSql() + " should be grouped by.");
} else if (expression instanceof AggregateFunction) {
if (checkWhetherNestedAggregateFunctionsExist((AggregateFunction) expression)) {
throw new AnalysisException("Aggregate functions in having clause can't be nested: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ void testJoinWithHaving() {
void testInvalidHaving() {
ExceptionChecker.expectThrowsWithMsg(
AnalysisException.class,
"a2 in having clause should be grouped by.",
"a2 should be grouped by.",
() -> PlanChecker.from(connectContext).analyze(
"SELECT a1 FROM t1 GROUP BY a1 HAVING a2 > 0"
));
Expand Down
25 changes: 25 additions & 0 deletions regression-test/data/nereids_syntax_p0/order_by_bind_priority.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !test_bind_order_by_with_aggfun1 --
10 -5 -10
4 -2 -4
4 1 3
6 3 6

-- !test_bind_order_by_with_aggfun2 --
4 -2 -4
10 -5 0
4 1 4
6 3 6

-- !test_bind_order_by_with_aggfun3 --
5 -5 5
2 -2 -2
2 1 4
3 3 3

-- !test_bind_order_by_in_no_agg_func_output --
1 4
2 -2
3 3
5 5

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("order_by_bind_priority") {
sql "SET enable_nereids_planner=true;"
sql "SET enable_fallback_to_original_planner=false;"

sql "drop table if exists t_order_by_bind_priority"
sql """create table t_order_by_bind_priority (c1 int, c2 int) distributed by hash(c1) properties("replication_num"="1");"""
sql "insert into t_order_by_bind_priority values(-2, -2),(1,2),(1,2),(3,3),(-5,5);"
sql "sync"


qt_test_bind_order_by_with_aggfun1 "select 2*abs(sum(c1)) as c1, c1,sum(c1)+c1 from t_order_by_bind_priority group by c1 order by sum(c1)+c1 asc;"
qt_test_bind_order_by_with_aggfun2 "select 2*abs(sum(c1)) as c2, c1,sum(c1)+c2 from t_order_by_bind_priority group by c1,c2 order by sum(c1)+c2 asc;"
qt_test_bind_order_by_with_aggfun3 "select abs(sum(c1)) as c1, c1,sum(c2) as c2 from t_order_by_bind_priority group by c1 order by sum(c1) asc;"
qt_test_bind_order_by_in_no_agg_func_output "select abs(c1) xx, sum(c2) from t_order_by_bind_priority group by xx order by min(xx)"
test {
sql "select abs(sum(c1)) as c1, c1,sum(c2) as c2 from t_order_by_bind_priority group by c1 order by sum(c1)+c2 asc;"
exception "c2 should be grouped by."
}


}