From aa4ffd21b1c4f7f730325029c713a0008aba95f5 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Wed, 7 Jan 2026 16:28:06 +0800 Subject: [PATCH] fix --- .../nereids/rules/rewrite/MergeAggregate.java | 1 + .../rules/rewrite/MergeAggregateTest.java | 124 ++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergeAggregateTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 4c02f53bdec082..346fc7edf92369 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -100,6 +100,7 @@ private Plan mergeAggProjectAgg(LogicalAggregate innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg); // rewrite agg function. e.g. max(max) List replacedAggFunc = replacedOutputExpressions.stream() + .filter(e -> e.containsType(AggregateFunction.class)) .map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc)) .collect(Collectors.toList()); // replace groupByKeys directly refer to the slot below the project diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergeAggregateTest.java new file mode 100644 index 00000000000000..6d69cd79f983d3 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergeAggregateTest.java @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.types.IntegerType; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.List; + +/** + * Unit tests for {@link MergeAggregate}, specifically testing the fix for filtering + * aggregate functions in mergeAggProjectAgg method. + */ +public class MergeAggregateTest { + + private MergeAggregate mergeAggregate; + + @BeforeEach + public void setUp() { + mergeAggregate = new MergeAggregate(); + } + + @Test + public void testMergeAggProjectAggWithMixedExpressions() throws Exception { + // This test verifies the fix at line 103-104 where we filter expressions + // to only process those containing AggregateFunction. + // The bug was that non-aggregate expressions (like SlotReference) were + // being passed to rewriteAggregateFunction, which could cause errors. + + // Create inner aggregate: group by a, output a, sum(b) as sumBAlias + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE); + Sum sumB = new Sum(b); + Alias sumBAlias = new Alias(sumB, "sumBAlias"); + + LogicalEmptyRelation emptyRelation = new LogicalEmptyRelation( + org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator.newRelationId(), + ImmutableList.of()); + + LogicalAggregate innerAgg = new LogicalAggregate<>( + ImmutableList.of(a), + ImmutableList.of(a, sumBAlias), + emptyRelation); + + // Create project: projects = [a as colA, sumBAlias] + SlotReference colA = new SlotReference( + sumBAlias.getExprId(), "colA", IntegerType.INSTANCE, true, ImmutableList.of()); + // Create a slot reference for sumBAlias from inner aggregate output + Slot sumBSlot = sumBAlias.toSlot(); + LogicalProject> project = new LogicalProject<>( + ImmutableList.of(colA, sumBSlot), + innerAgg); + + // Create outer aggregate: group by colA, output colA, sum(sumBAlias) + Slot col2FromProject = project.getOutput().get(0); + Slot col1FromProject = project.getOutput().get(1); + Sum sumSum = new Sum(col1FromProject); + Alias sumSumAlias = new Alias(sumSum, "sumSum"); + + // Outer aggregate output contains: + // 1. colA (SlotReference - non-aggregate, should be filtered out) + // 2. max(sumBAlias) (AggregateFunction - should be processed) + // 3. sumBAlias (SlotReference - non-aggregate, should be filtered out) + List outerAggOutput = ImmutableList.of( + col2FromProject, + sumSumAlias + ); + + LogicalAggregate>> outerAgg = new LogicalAggregate<>( + ImmutableList.of(col2FromProject), + outerAggOutput, + project); + + // Use reflection to call the private method + Method method = mergeAggregate.getClass().getDeclaredMethod( + "mergeAggProjectAgg", LogicalAggregate.class); + method.setAccessible(true); + + // This should not throw an exception + // Before the fix, non-aggregate expressions would be passed to rewriteAggregateFunction + // which could cause errors. After the fix, only expressions containing AggregateFunction + // are processed. + Plan result = (Plan) method.invoke(mergeAggregate, outerAgg); + + Assertions.assertNotNull(result); + Assertions.assertTrue(result instanceof LogicalProject); + + LogicalProject resultProject = (LogicalProject) result; + Assertions.assertNotNull(resultProject.child(0)); + Assertions.assertTrue(resultProject.child(0) instanceof LogicalAggregate); + + LogicalAggregate aggregate = (LogicalAggregate) resultProject.child(0); + Assertions.assertEquals(aggregate.getOutput().size(), 2); + } +}