diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java index 067a8882debe1a..8ff2b7a3bbc2a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java @@ -60,8 +60,10 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.Set; import java.util.UUID; import java.util.stream.Collectors; @@ -439,17 +441,7 @@ public void analyze(Analyzer analyzer) throws UserException { } } if (groupByClause != null && groupByClause.isGroupByExtension()) { - for (SelectListItem item : selectList.getItems()) { - if (item.getExpr() instanceof FunctionCallExpr && item.getExpr().fn instanceof AggregateFunction) { - for (Expr expr : groupByClause.getGroupingExprs()) { - if (item.getExpr().contains(expr)) { - throw new AnalysisException("column: " + expr.toSql() + " cannot both in select list and " - + "aggregate functions when using GROUPING SETS/CUBE/ROLLUP, please use union" - + " instead."); - } - } - } - } + checkSelectItemsForGroupingSet(); groupingInfo = new GroupingInfo(analyzer, groupByClause); groupingInfo.substituteGroupingFn(resultExprs, analyzer); } else { @@ -562,6 +554,48 @@ public void analyze(Analyzer analyzer) throws UserException { } } + /** + * check whether grouping set columns are in the agg function + * within the select items. If true, throw an AnalysisException. + * + * @throws AnalysisException when check failed + */ + public void checkSelectItemsForGroupingSet() throws AnalysisException { + for (SelectListItem item : selectList.getItems()) { + Expr selectExprRoot = item.getExpr(); + List aggFunctions = getAggFuncExprsFromChildren(selectExprRoot); + for (Expr aggFunction : aggFunctions) { + for (Expr groupingExpr : groupByClause.getGroupingExprs()) { + if (aggFunction.contains(groupingExpr)) { + throw new AnalysisException("column: " + groupingExpr.toSql() + " cannot both in" + + " select list and aggregate functions when using GROUPING SETS/CUBE/ROLLUP," + + " please use union instead."); + } + } + } + } + } + + /** + * Get all AggregateFunctions,which are under the `expr` in the Expr-tree. + * + * @param expr + * + * @return list of exprs + */ + public List getAggFuncExprsFromChildren(Expr expr) { + List aggFuncExprs = new LinkedList<>(); + Queue exprsQueue = new LinkedList<>(); + exprsQueue.offer(expr); + while (!exprsQueue.isEmpty()) { + Expr exprChild = exprsQueue.poll(); + if (exprChild instanceof FunctionCallExpr && exprChild.getFn() instanceof AggregateFunction) { + aggFuncExprs.add(exprChild); + } + exprsQueue.addAll(exprChild.getChildrenWithoutCast()); + } + return aggFuncExprs; + } public List getTableRefIds() { List result = Lists.newArrayList(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java index 73fe9bf24bfddc..11eed15fbbc229 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java @@ -452,6 +452,17 @@ public void testStringType() { Assertions.assertTrue(exception.getMessage().contains("String Type should not be used in key column[k1].")); } + @Test + public void testSelectAggregateItemCheckOnGroupingSet() throws Exception { + String sql = "explain select k1,if(k2=null, null, count(distinct k2)) from db1.tbl4" + + " group by grouping sets((k1),(k1,k2))"; + String errorMessage = "errCode = 2, detailMessage = column: `k2` cannot both in select list and " + + "aggregate functions when using GROUPING SETS/CUBE/ROLLUP, please use union instead."; + StmtExecutor stmtExecutor = new StmtExecutor(connectContext, sql); + stmtExecutor.execute(); + Assertions.assertTrue(connectContext.getState().getErrorMessage().contains(errorMessage)); + } + @Test public void testPushDownPredicateOnGroupingSetAggregate() throws Exception { String sql = "explain select k1, k2, count(distinct v1) from db1.tbl4"