diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index dbddc2e14bb690..0f6230a5fb409e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -541,26 +541,53 @@ private LogicalAggregate pushdownCountOnIndex( LogicalFilter filter, LogicalOlapScan olapScan, CascadesContext cascadesContext) { - PhysicalOlapScan physicalOlapScan - = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan() + + PhysicalOlapScan physicalOlapScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan() .build() .transform(olapScan, cascadesContext) .get(0); + + List argumentsOfAggregateFunction = normalizeArguments(agg.getAggregateFunctions(), project); + + if (!onlyContainsSlot(argumentsOfAggregateFunction)) { + return agg; + } + + return agg.withChildren(ImmutableList.of( + project != null + ? project.withChildren(ImmutableList.of( + filter.withChildren(ImmutableList.of( + new PhysicalStorageLayerAggregate( + physicalOlapScan, PushDownAggOp.COUNT_ON_MATCH))))) + : filter.withChildren(ImmutableList.of( + new PhysicalStorageLayerAggregate( + physicalOlapScan, PushDownAggOp.COUNT_ON_MATCH))) + )); + } + + private List normalizeArguments(Set aggregateFunctions, + @Nullable LogicalProject project) { + List arguments = aggregateFunctions.stream() + .flatMap(aggregateFunction -> aggregateFunction.getArguments().stream()) + .collect(ImmutableList.toImmutableList()); + if (project != null) { - return agg.withChildren(ImmutableList.of( - project.withChildren(ImmutableList.of( - filter.withChildren(ImmutableList.of( - new PhysicalStorageLayerAggregate( - physicalOlapScan, - PushDownAggOp.COUNT_ON_MATCH))))) - )); - } else { - return agg.withChildren(ImmutableList.of( - filter.withChildren(ImmutableList.of( - new PhysicalStorageLayerAggregate( - physicalOlapScan, - PushDownAggOp.COUNT_ON_MATCH))))); + arguments = Project.findProject(arguments, project.getProjects()) + .stream() + .map(p -> p instanceof Alias ? p.child(0) : p) + .collect(ImmutableList.toImmutableList()); } + + return arguments; + } + + private boolean onlyContainsSlot(List arguments) { + return arguments.stream().allMatch(argument -> { + if (argument instanceof SlotReference) { + return true; + } + return false; + }); } //select /*+SET_VAR(enable_pushdown_minmax_on_unique=true) */min(user_id) from table_unique; diff --git a/regression-test/data/inverted_index_p0/test_count_on_index.out b/regression-test/data/inverted_index_p0/test_count_on_index.out index 3c0f47e7f8baf9..f74f3dc927aeed 100644 --- a/regression-test/data/inverted_index_p0/test_count_on_index.out +++ b/regression-test/data/inverted_index_p0/test_count_on_index.out @@ -77,3 +77,6 @@ -- !sql_bad -- 0 1 +-- !sql_bad2 -- +0 1 + diff --git a/regression-test/suites/inverted_index_p0/test_count_on_index.groovy b/regression-test/suites/inverted_index_p0/test_count_on_index.groovy index 320fc65ff76bc9..ec2c556d8357e6 100644 --- a/regression-test/suites/inverted_index_p0/test_count_on_index.groovy +++ b/regression-test/suites/inverted_index_p0/test_count_on_index.groovy @@ -313,6 +313,30 @@ suite("test_count_on_index_httplogs", "p0") { contains "pushAggOp=NONE" } qt_sql_bad "${bad_sql}" + def bad_sql2 = """ + SELECT + COUNT(cond1) AS num1, + COUNT(cond2) AS num2 + FROM ( + SELECT + CASE + WHEN c IN ('c1', 'c2', 'c3') AND d = 'd1' THEN b + END AS cond1, + CASE + WHEN e = 'e1' AND c IN ('c1', 'c2', 'c3') THEN b + END AS cond2 + FROM + ${tableName5} + WHERE + a = '2024-07-26' + AND e = 'e1' + ) AS project; + """ + explain { + sql("${bad_sql2}") + contains "pushAggOp=NONE" + } + qt_sql_bad2 "${bad_sql2}" } finally { //try_sql("DROP TABLE IF EXISTS ${testTable}") }