-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[fix](agg) Fix grouping function handling and repeatSlotIdList calculation in DecomposeRepeatWithPreAggregation #60091
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
4403516
43a9454
cbe1836
8b4c951
ad4f86f
ba33730
b51590d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,16 +18,15 @@ | |
| package org.apache.doris.nereids.trees.plans.algebra; | ||
|
|
||
| import org.apache.doris.nereids.exceptions.AnalysisException; | ||
| import org.apache.doris.nereids.trees.expressions.Alias; | ||
| import org.apache.doris.nereids.trees.expressions.Expression; | ||
| import org.apache.doris.nereids.trees.expressions.NamedExpression; | ||
| import org.apache.doris.nereids.trees.expressions.Slot; | ||
| import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; | ||
| import org.apache.doris.nereids.trees.plans.Plan; | ||
| import org.apache.doris.nereids.util.BitUtils; | ||
| import org.apache.doris.nereids.util.ExpressionUtils; | ||
|
|
||
| import com.google.common.collect.ImmutableList; | ||
| import com.google.common.collect.ImmutableSet; | ||
| import com.google.common.collect.Lists; | ||
| import com.google.common.collect.Maps; | ||
| import com.google.common.collect.Sets; | ||
|
|
@@ -117,14 +116,35 @@ default List<List<Long>> computeGroupingFunctionsValues() { | |
|
|
||
| /** | ||
| * Flatten the grouping sets and build a GroupingSetShapes from them. | ||
| * This method ensures that all expressions referenced by grouping functions are included | ||
| * in the flattenGroupingSetExpression, even if they are not in any grouping set. | ||
| * This is necessary for optimization scenarios where some expressions may only exist | ||
| * in the maximum grouping set that was removed during optimization. | ||
| */ | ||
| default GroupingSetShapes toShapes() { | ||
| Set<Expression> flattenGroupingSet = ImmutableSet.copyOf(ExpressionUtils.flatExpressions(getGroupingSets())); | ||
| // Collect all expressions referenced by grouping functions to ensure they are included | ||
| // in flattenGroupingSetExpression, even if they are not in any grouping set. | ||
| // This maintains semantic constraints while allowing optimization. | ||
| List<GroupingScalarFunction> groupingFunctions = ExpressionUtils.collectToList( | ||
| getOutputExpressions(), GroupingScalarFunction.class::isInstance); | ||
| Set<Expression> groupingFunctionArgs = Sets.newLinkedHashSet(); | ||
| for (GroupingScalarFunction function : groupingFunctions) { | ||
| groupingFunctionArgs.addAll(function.getArguments()); | ||
| } | ||
| // Merge grouping set expressions with grouping function arguments | ||
| // Use LinkedHashSet to preserve order: grouping sets first, then grouping function args | ||
| Set<Expression> flattenGroupingSet = Sets.newLinkedHashSet(getGroupByExpressions()); | ||
| for (Expression arg : groupingFunctionArgs) { | ||
| if (!flattenGroupingSet.contains(arg)) { | ||
| flattenGroupingSet.add(arg); | ||
| } | ||
| } | ||
| List<GroupingSetShape> shapes = Lists.newArrayList(); | ||
| for (List<Expression> groupingSet : getGroupingSets()) { | ||
| List<Boolean> shouldBeErasedToNull = Lists.newArrayListWithCapacity(flattenGroupingSet.size()); | ||
| for (Expression groupingSetExpression : flattenGroupingSet) { | ||
| shouldBeErasedToNull.add(!groupingSet.contains(groupingSetExpression)); | ||
| for (Expression expression : flattenGroupingSet) { | ||
| // If expression is not in the current grouping set, it should be erased to null | ||
| shouldBeErasedToNull.add(!groupingSet.contains(expression)); | ||
| } | ||
| shapes.add(new GroupingSetShape(shouldBeErasedToNull)); | ||
| } | ||
|
|
@@ -140,8 +160,8 @@ default GroupingSetShapes toShapes() { | |
| * | ||
| * return: [(4, 3), (3)] | ||
| */ | ||
| default List<Set<Integer>> computeRepeatSlotIdList(List<Integer> slotIdList) { | ||
| List<Set<Integer>> groupingSetsIndexesInOutput = getGroupingSetsIndexesInOutput(); | ||
| default List<Set<Integer>> computeRepeatSlotIdList(List<Integer> slotIdList, List<Slot> outputSlots) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a ut for this function
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| List<Set<Integer>> groupingSetsIndexesInOutput = getGroupingSetsIndexesInOutput(outputSlots); | ||
| List<Set<Integer>> repeatSlotIdList = Lists.newArrayList(); | ||
| for (Set<Integer> groupingSetIndex : groupingSetsIndexesInOutput) { | ||
| // keep order | ||
|
|
@@ -160,8 +180,8 @@ default List<Set<Integer>> computeRepeatSlotIdList(List<Integer> slotIdList) { | |
| * e.g. groupingSets=((b, a), (a)), output=[a, b] | ||
| * return ((1, 0), (0)) | ||
| */ | ||
| default List<Set<Integer>> getGroupingSetsIndexesInOutput() { | ||
| Map<Expression, Integer> indexMap = indexesOfOutput(); | ||
| default List<Set<Integer>> getGroupingSetsIndexesInOutput(List<Slot> outputSlots) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a ut for this function
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| Map<Expression, Integer> indexMap = indexesOfOutput(outputSlots); | ||
|
|
||
| List<Set<Integer>> groupingSetsIndex = Lists.newArrayList(); | ||
| List<List<Expression>> groupingSets = getGroupingSets(); | ||
|
|
@@ -184,23 +204,22 @@ default List<Set<Integer>> getGroupingSetsIndexesInOutput() { | |
| /** | ||
| * indexesOfOutput: build the map from each output expression to its index in the output. | ||
| * | ||
| * e.g. output=[a + 1, b + 2, c] | ||
| * e.g. outputSlots=[a + 1, b + 2, c] | ||
| * | ||
| * return the map( | ||
| * `a + 1`: 0, | ||
| * `b + 2`: 1, | ||
| * `c`: 2 | ||
| * ) | ||
| * | ||
| * Use the outputSlots supplied by physicalPlanTranslator instead of getOutputExpressions() in this method, | ||
| * because outputSlots has the same order as slotIdList. | ||
| */ | ||
| default Map<Expression, Integer> indexesOfOutput() { | ||
| static Map<Expression, Integer> indexesOfOutput(List<Slot> outputSlots) { | ||
| Map<Expression, Integer> indexes = Maps.newLinkedHashMap(); | ||
| List<NamedExpression> outputs = getOutputExpressions(); | ||
| for (int i = 0; i < outputs.size(); i++) { | ||
| NamedExpression output = outputs.get(i); | ||
| for (int i = 0; i < outputSlots.size(); i++) { | ||
| NamedExpression output = outputSlots.get(i); | ||
| indexes.put(output, i); | ||
| if (output instanceof Alias) { | ||
| indexes.put(((Alias) output).child(), i); | ||
| } | ||
| } | ||
| return indexes; | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,201 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use it except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| package org.apache.doris.nereids.trees.plans.algebra; | ||
|
|
||
| import org.apache.doris.nereids.trees.expressions.Alias; | ||
| import org.apache.doris.nereids.trees.expressions.Expression; | ||
| import org.apache.doris.nereids.trees.expressions.Slot; | ||
| import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; | ||
| import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId; | ||
| import org.apache.doris.nereids.trees.plans.Plan; | ||
| import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; | ||
| import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; | ||
| import org.apache.doris.nereids.util.PlanConstructor; | ||
|
|
||
| import com.google.common.collect.ImmutableList; | ||
| import com.google.common.collect.Sets; | ||
| import org.junit.jupiter.api.Assertions; | ||
| import org.junit.jupiter.api.BeforeEach; | ||
| import org.junit.jupiter.api.Test; | ||
|
|
||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.Set; | ||
|
|
||
| /** | ||
| * Unit tests for {@link Repeat} interface default methods: | ||
| * toShapes, indexesOfOutput, getGroupingSetsIndexesInOutput, computeRepeatSlotIdList. | ||
| */ | ||
| public class RepeatTest { | ||
|
|
||
| private LogicalOlapScan scan; | ||
| private Slot id; | ||
| private Slot gender; | ||
| private Slot name; | ||
| private Slot age; | ||
|
|
||
| @BeforeEach | ||
| public void setUp() { | ||
| scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("db")); | ||
| id = scan.getOutput().get(0); | ||
| gender = scan.getOutput().get(1); | ||
| name = scan.getOutput().get(2); | ||
| age = scan.getOutput().get(3); | ||
| } | ||
|
|
||
| @Test | ||
| public void testToShapes() { | ||
| // grouping sets: (id, name), (id), () | ||
| // flatten = [id, name], shapes: [false,false], [false,true], [true,true] | ||
| List<List<Expression>> groupingSets = ImmutableList.of( | ||
| ImmutableList.of(id, name), | ||
| ImmutableList.of(id), | ||
| ImmutableList.of() | ||
| ); | ||
| Alias alias = new Alias(new Sum(name), "sum(name)"); | ||
| Repeat<Plan> repeat = new LogicalRepeat<>( | ||
| groupingSets, | ||
| ImmutableList.of(id, name, alias), | ||
| scan | ||
| ); | ||
|
|
||
| Repeat.GroupingSetShapes shapes = repeat.toShapes(); | ||
|
|
||
| Assertions.assertEquals(2, shapes.flattenGroupingSetExpression.size()); | ||
| Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(id)); | ||
| Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(name)); | ||
| Assertions.assertEquals(3, shapes.shapes.size()); | ||
|
|
||
| // (id, name) -> [false, false] | ||
| Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(0)); | ||
| Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(1)); | ||
| Assertions.assertEquals(0L, shapes.shapes.get(0).computeLongValue()); | ||
|
|
||
| // (id) -> [false, true] (id in set, name not) | ||
| Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(0)); | ||
| Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(1)); | ||
| Assertions.assertEquals(1L, shapes.shapes.get(1).computeLongValue()); | ||
|
|
||
| // () -> [true, true] | ||
| Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(0)); | ||
| Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(1)); | ||
| Assertions.assertEquals(3L, shapes.shapes.get(2).computeLongValue()); | ||
| } | ||
|
|
||
| @Test | ||
| public void testToShapesWithGroupingFunction() { | ||
| // grouping_id(gender, age) references expressions that are in no grouping set; they are appended after the group-by expressions, so flatten = [name, id, gender, age] | ||
| List<List<Expression>> groupingSets = ImmutableList.of( | ||
| ImmutableList.of(name), ImmutableList.of(name, id), ImmutableList.of()); | ||
| Alias groupingAlias = new Alias(new GroupingId(gender, age), "grouping_id(id)"); | ||
| Repeat<Plan> repeat = new LogicalRepeat<>( | ||
| groupingSets, | ||
| ImmutableList.of(name, groupingAlias), | ||
| scan | ||
| ); | ||
|
|
||
| Repeat.GroupingSetShapes shapes = repeat.toShapes(); | ||
|
|
||
| // flatten = [name, id] from the group-by expressions + [gender, age] from the grouping function args | ||
| Assertions.assertEquals(4, shapes.flattenGroupingSetExpression.size()); | ||
| Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(name)); | ||
| Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(id)); | ||
| Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(gender)); | ||
| Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(age)); | ||
|
|
||
| Assertions.assertEquals(3, shapes.shapes.size()); | ||
| // (name) -> name not erased, id,gender,age erased | ||
| Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(0)); | ||
| Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(1)); | ||
| Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(2)); | ||
| Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(3)); | ||
| // (name, id) -> name,id not erased, gender and age erased | ||
| Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(0)); | ||
| Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(1)); | ||
| Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(2)); | ||
| Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(3)); | ||
| //() -> all erased | ||
| Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(0)); | ||
| Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(1)); | ||
| Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(2)); | ||
| Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(3)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testIndexesOfOutput() { | ||
| List<Slot> outputSlots = ImmutableList.of(id, gender, name, age); | ||
| Map<Expression, Integer> indexes = Repeat.indexesOfOutput(outputSlots); | ||
| Assertions.assertEquals(4, indexes.size()); | ||
| Assertions.assertEquals(0, indexes.get(id)); | ||
| Assertions.assertEquals(1, indexes.get(gender)); | ||
| Assertions.assertEquals(2, indexes.get(name)); | ||
| Assertions.assertEquals(3, indexes.get(age)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testGetGroupingSetsIndexesInOutput() { | ||
| // groupingSets=((name, id), (id), (gender)), output=[id, name, gender] | ||
| // expect:((1,0),(0),(2)) | ||
| List<List<Expression>> groupingSets = ImmutableList.of( | ||
| ImmutableList.of(name, id), | ||
| ImmutableList.of(id), | ||
| ImmutableList.of(gender) | ||
| ); | ||
| Alias groupingId = new Alias(new GroupingId(id, name)); | ||
| Repeat<Plan> repeat = new LogicalRepeat<>( | ||
| groupingSets, | ||
| ImmutableList.of(id, name, gender, groupingId), | ||
| scan | ||
| ); | ||
| List<Slot> outputSlots = ImmutableList.of(id, name, gender, groupingId.toSlot()); | ||
|
|
||
| List<Set<Integer>> result = repeat.getGroupingSetsIndexesInOutput(outputSlots); | ||
|
|
||
| Assertions.assertEquals(3, result.size()); | ||
| // (name, id) -> indexes {1, 0} | ||
| Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(1, 0)), result.get(0)); | ||
| // (id) -> index {0} | ||
| Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(0)), result.get(1)); | ||
| // (gender) -> index {2} | ||
| Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(2)), result.get(2)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testComputeRepeatSlotIdList() { | ||
| // groupingSets=((name, id), (id)), output=[id, name], slotIdList=[3, 4] (id->3, name->4) | ||
| List<List<Expression>> groupingSets = ImmutableList.of( | ||
| ImmutableList.of(name, id), | ||
| ImmutableList.of(id) | ||
| ); | ||
| Repeat<Plan> repeat = new LogicalRepeat<>( | ||
| groupingSets, | ||
| ImmutableList.of(id, name), | ||
| scan | ||
| ); | ||
| List<Slot> outputSlots = ImmutableList.of(id, name); | ||
| List<Integer> slotIdList = ImmutableList.of(3, 4); | ||
|
|
||
| List<Set<Integer>> result = repeat.computeRepeatSlotIdList(slotIdList, outputSlots); | ||
|
|
||
| Assertions.assertEquals(2, result.size()); | ||
| // (name, id) -> indexes {1,0} -> slot ids {4, 3} | ||
| Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(4, 3)), result.get(0)); | ||
| // (id) -> index {0} -> slot id {3} | ||
| Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(3)), result.get(1)); | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a ut for this function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done