Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2567,7 +2567,7 @@ public PlanFragment visitPhysicalRepeat(PhysicalRepeat<? extends Plan> repeat, P
// cube and rollup already convert to grouping sets in LogicalPlanBuilder.withAggregate()
GroupingInfo groupingInfo = new GroupingInfo(outputTuple, preRepeatExprs);

List<Set<Integer>> repeatSlotIdList = repeat.computeRepeatSlotIdList(getSlotIds(outputTuple));
List<Set<Integer>> repeatSlotIdList = repeat.computeRepeatSlotIdList(getSlotIds(outputTuple), outputSlots);
Set<Integer> allSlotId = repeatSlotIdList.stream()
.flatMap(Set::stream)
.collect(ImmutableSet.toImmutableSet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,9 @@ private LogicalUnion constructUnion(LogicalPlan aggregateProject, LogicalPlan di
* Determine if optimization is possible; if so, return the index of the largest group.
* The optimization requires:
* 1. The aggregate's child must be a LogicalRepeat
* 2. All aggregate functions must be Sum, Min, or Max (non-distinct)
* 3. No GroupingScalarFunction in repeat output
* 4. More than 3 grouping sets
* 5. There exists a grouping set that contains all other grouping sets
*
* 2. All aggregate functions must be in SUPPORT_AGG_FUNCTIONS.
* 3. More than 3 grouping sets
* 4. There exists a grouping set that contains all other grouping sets
* @param aggregate the aggregate plan to check
* @return value -1 means can not be optimized, values other than -1
* represent the index of the set that contains all other sets
Expand Down Expand Up @@ -401,7 +399,11 @@ private int canOptimize(LogicalAggregate<? extends Plan> aggregate) {
if (groupingSets.size() <= 3) {
return -1;
}
return findMaxGroupingSetIndex(groupingSets);
int maxGroupIndex = findMaxGroupingSetIndex(groupingSets);
if (maxGroupIndex < 0) {
return -1;
}
return maxGroupIndex;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@
package org.apache.doris.nereids.trees.plans.algebra;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.BitUtils;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
Expand Down Expand Up @@ -117,14 +116,35 @@ default List<List<Long>> computeGroupingFunctionsValues() {

/**
* flatten the grouping sets and build to a GroupingSetShapes.
* This method ensures that all expressions referenced by grouping functions are included
* in the flattenGroupingSetExpression, even if they are not in any grouping set.
* This is necessary for optimization scenarios where some expressions may only exist
* in the maximum grouping set that was removed during optimization.
*/
default GroupingSetShapes toShapes() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a ut for this function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Set<Expression> flattenGroupingSet = ImmutableSet.copyOf(ExpressionUtils.flatExpressions(getGroupingSets()));
// Collect all expressions referenced by grouping functions to ensure they are included
// in flattenGroupingSetExpression, even if they are not in any grouping set.
// This maintains semantic constraints while allowing optimization.
List<GroupingScalarFunction> groupingFunctions = ExpressionUtils.collectToList(
getOutputExpressions(), GroupingScalarFunction.class::isInstance);
Set<Expression> groupingFunctionArgs = Sets.newLinkedHashSet();
for (GroupingScalarFunction function : groupingFunctions) {
groupingFunctionArgs.addAll(function.getArguments());
}
// Merge grouping set expressions with grouping function arguments
// Use LinkedHashSet to preserve order: grouping sets first, then grouping function args
Set<Expression> flattenGroupingSet = Sets.newLinkedHashSet(getGroupByExpressions());
for (Expression arg : groupingFunctionArgs) {
if (!flattenGroupingSet.contains(arg)) {
flattenGroupingSet.add(arg);
}
}
List<GroupingSetShape> shapes = Lists.newArrayList();
for (List<Expression> groupingSet : getGroupingSets()) {
List<Boolean> shouldBeErasedToNull = Lists.newArrayListWithCapacity(flattenGroupingSet.size());
for (Expression groupingSetExpression : flattenGroupingSet) {
shouldBeErasedToNull.add(!groupingSet.contains(groupingSetExpression));
for (Expression expression : flattenGroupingSet) {
// If expression is not in the current grouping set, it should be erased to null
shouldBeErasedToNull.add(!groupingSet.contains(expression));
}
shapes.add(new GroupingSetShape(shouldBeErasedToNull));
}
Expand All @@ -140,8 +160,8 @@ default GroupingSetShapes toShapes() {
*
* return: [(4, 3), (3)]
*/
default List<Set<Integer>> computeRepeatSlotIdList(List<Integer> slotIdList) {
List<Set<Integer>> groupingSetsIndexesInOutput = getGroupingSetsIndexesInOutput();
default List<Set<Integer>> computeRepeatSlotIdList(List<Integer> slotIdList, List<Slot> outputSlots) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a ut for this function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

List<Set<Integer>> groupingSetsIndexesInOutput = getGroupingSetsIndexesInOutput(outputSlots);
List<Set<Integer>> repeatSlotIdList = Lists.newArrayList();
for (Set<Integer> groupingSetIndex : groupingSetsIndexesInOutput) {
// keep order
Expand All @@ -160,8 +180,8 @@ default List<Set<Integer>> computeRepeatSlotIdList(List<Integer> slotIdList) {
* e.g. groupingSets=((b, a), (a)), output=[a, b]
* return ((1, 0), (1))
*/
default List<Set<Integer>> getGroupingSetsIndexesInOutput() {
Map<Expression, Integer> indexMap = indexesOfOutput();
default List<Set<Integer>> getGroupingSetsIndexesInOutput(List<Slot> outputSlots) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a ut for this function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Map<Expression, Integer> indexMap = indexesOfOutput(outputSlots);

List<Set<Integer>> groupingSetsIndex = Lists.newArrayList();
List<List<Expression>> groupingSets = getGroupingSets();
Expand All @@ -184,23 +204,22 @@ default List<Set<Integer>> getGroupingSetsIndexesInOutput() {
/**
* indexesOfOutput: get the indexes which mapping from the expression to the index in the output.
*
* e.g. output=[a + 1, b + 2, c]
* e.g. outputSlots=[a + 1, b + 2, c]
*
* return the map(
* `a + 1`: 0,
* `b + 2`: 1,
* `c`: 2
* )
*
* Use outputSlots in physicalPlanTranslator instead of getOutputExpressions() in this method,
* because the outputSlots have same order with slotIdList.
*/
default Map<Expression, Integer> indexesOfOutput() {
static Map<Expression, Integer> indexesOfOutput(List<Slot> outputSlots) {
Map<Expression, Integer> indexes = Maps.newLinkedHashMap();
List<NamedExpression> outputs = getOutputExpressions();
for (int i = 0; i < outputs.size(); i++) {
NamedExpression output = outputs.get(i);
for (int i = 0; i < outputSlots.size(); i++) {
NamedExpression output = outputSlots.get(i);
indexes.put(output, i);
if (output instanceof Alias) {
indexes.put(((Alias) output).child(), i);
}
}
return indexes;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use it except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.plans.algebra;

import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.util.PlanConstructor;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Unit tests for {@link Repeat} interface default methods:
* toShapes, indexesOfOutput, getGroupingSetsIndexesInOutput, computeRepeatSlotIdList.
*/
public class RepeatTest {

private LogicalOlapScan scan;
private Slot id;
private Slot gender;
private Slot name;
private Slot age;

@BeforeEach
public void setUp() {
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("db"));
id = scan.getOutput().get(0);
gender = scan.getOutput().get(1);
name = scan.getOutput().get(2);
age = scan.getOutput().get(3);
}

@Test
public void testToShapes() {
// grouping sets: (id, name), (id), ()
// flatten = [id, name], shapes: [false,false], [false,true], [true,true]
List<List<Expression>> groupingSets = ImmutableList.of(
ImmutableList.of(id, name),
ImmutableList.of(id),
ImmutableList.of()
);
Alias alias = new Alias(new Sum(name), "sum(name)");
Repeat<Plan> repeat = new LogicalRepeat<>(
groupingSets,
ImmutableList.of(id, name, alias),
scan
);

Repeat.GroupingSetShapes shapes = repeat.toShapes();

Assertions.assertEquals(2, shapes.flattenGroupingSetExpression.size());
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(id));
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(name));
Assertions.assertEquals(3, shapes.shapes.size());

// (id, name) -> [false, false]
Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(0));
Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(1));
Assertions.assertEquals(0L, shapes.shapes.get(0).computeLongValue());

// (id) -> [false, true] (id in set, name not)
Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(0));
Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(1));
Assertions.assertEquals(1L, shapes.shapes.get(1).computeLongValue());

// () -> [true, true]
Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(0));
Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(1));
Assertions.assertEquals(3L, shapes.shapes.get(2).computeLongValue());
}

@Test
public void testToShapesWithGroupingFunction() {
// grouping(id) adds id to flatten if not present; single set (name) has flatten [name, id]
List<List<Expression>> groupingSets = ImmutableList.of(
ImmutableList.of(name), ImmutableList.of(name, id), ImmutableList.of());
Alias groupingAlias = new Alias(new GroupingId(gender, age), "grouping_id(id)");
Repeat<Plan> repeat = new LogicalRepeat<>(
groupingSets,
ImmutableList.of(name, groupingAlias),
scan
);

Repeat.GroupingSetShapes shapes = repeat.toShapes();

// flatten = [name] from getGroupBy + [id] from grouping function arg
Assertions.assertEquals(4, shapes.flattenGroupingSetExpression.size());
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(name));
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(id));
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(gender));
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(age));

Assertions.assertEquals(3, shapes.shapes.size());
// (name) -> name not erased, id,gender,age erased
Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(0));
Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(1));
Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(2));
Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(3));
// (name, id) -> name,id not erased, gender and age erased
Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(0));
Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(1));
Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(2));
Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(3));
//() -> all erased
Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(0));
Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(1));
Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(2));
Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(3));
}

@Test
public void testIndexesOfOutput() {
List<Slot> outputSlots = ImmutableList.of(id, gender, name, age);
Map<Expression, Integer> indexes = Repeat.indexesOfOutput(outputSlots);
Assertions.assertEquals(4, indexes.size());
Assertions.assertEquals(0, indexes.get(id));
Assertions.assertEquals(1, indexes.get(gender));
Assertions.assertEquals(2, indexes.get(name));
Assertions.assertEquals(3, indexes.get(age));
}

@Test
public void testGetGroupingSetsIndexesInOutput() {
// groupingSets=((name, id), (id), (gender)), output=[id, name, gender]
// expect:((1,0),(0),(2))
List<List<Expression>> groupingSets = ImmutableList.of(
ImmutableList.of(name, id),
ImmutableList.of(id),
ImmutableList.of(gender)
);
Alias groupingId = new Alias(new GroupingId(id, name));
Repeat<Plan> repeat = new LogicalRepeat<>(
groupingSets,
ImmutableList.of(id, name, gender, groupingId),
scan
);
List<Slot> outputSlots = ImmutableList.of(id, name, gender, groupingId.toSlot());

List<Set<Integer>> result = repeat.getGroupingSetsIndexesInOutput(outputSlots);

Assertions.assertEquals(3, result.size());
// (name, id) -> indexes {1, 0}
Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(1, 0)), result.get(0));
// (id) -> index {0}
Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(0)), result.get(1));
// (gender) -> index {2}
Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(2)), result.get(2));
}

@Test
public void testComputeRepeatSlotIdList() {
// groupingSets=((name, id), (id)), output=[id, name], slotIdList=[3, 4] (id->3, name->4)
List<List<Expression>> groupingSets = ImmutableList.of(
ImmutableList.of(name, id),
ImmutableList.of(id)
);
Repeat<Plan> repeat = new LogicalRepeat<>(
groupingSets,
ImmutableList.of(id, name),
scan
);
List<Slot> outputSlots = ImmutableList.of(id, name);
List<Integer> slotIdList = ImmutableList.of(3, 4);

List<Set<Integer>> result = repeat.computeRepeatSlotIdList(slotIdList, outputSlots);

Assertions.assertEquals(2, result.size());
// (name, id) -> indexes {1,0} -> slot ids {4, 3}
Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(4, 3)), result.get(0));
// (id) -> index {0} -> slot id {3}
Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(3)), result.get(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
100000
100000
100000
100000

-- !sql_2_shape --
PhysicalCteAnchor ( cteId=CTEId#0 )
Expand All @@ -60,11 +59,9 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
-- !sql_2_result --
\N ALL 1 6 \N \N \N
\N ALL 1 6 \N \N \N
2020-01-02T00:00 ALL 1 6 \N 2020-01-02T00:00 \N
2020-01-02T00:00 ALL 1 6 \N 2020-01-02T00:00 \N
2020-01-03T00:00 ALL 1 6 \N 2020-01-03T00:00 \N
2020-01-03T00:00 ALL 1 6 \N 2020-01-03T00:00 \N
2020-01-04T00:00 ALL 1 6 \N 2020-01-04T00:00 \N
2020-01-04T00:00 ALL 1 6 \N 2020-01-04T00:00 \N
2020-01-04T00:00 ALL 1 6 \N \N a
2020-01-04T00:00 ALL 1 6 \N \N a
2020-01-04T00:00 ALL 1 6 \N \N b
2020-01-04T00:00 ALL 1 6 \N \N b
2020-01-04T00:00 ALL 1 7 \N \N \N

Loading