Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,11 @@ public class Rewriter extends AbstractBatchJobExecutor {
new EliminateFilter(),
new PushDownFilterThroughProject(),
new MergeProjects(),
new PruneOlapScanTablet()
new PruneOlapScanTablet(),
// SelectMaterializedIndexWithAggregate may change the nullability of agg functions
// need rerun AdjustAggregateNullableForEmptySet to make the nullability correct
// TODO: remove AdjustAggregateNullableForEmptySet when remove rbo mv selection rules
new AdjustAggregateNullableForEmptySet()
),
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.RequireProperties;
import org.apache.doris.nereids.properties.RequirePropertiesSupplier;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.plans.AggMode;
Expand Down Expand Up @@ -335,14 +337,32 @@ public PhysicalHashAggregate<CHILD_TYPE> setTopnPushInfo(TopnPushInfo topnPushIn
return this;
}

/**
* sql: select sum(distinct c1) from t;
* assume c1 is not null, because there is no group by
* sum(distinct c1)'s nullable is alwasNullable in rewritten phase.
* But in implementation phase, we may create 3 phase agg with group by key c1.
* And the sum(distinct c1)'s nullability should be changed depending on if there is any group by expressions.
* This pr update the agg function's nullability accordingly
*/
private List<NamedExpression> adjustNullableForOutputs(List<NamedExpression> outputs, boolean alwaysNullable) {
return ExpressionUtils.rewriteDownShortCircuit(outputs, output -> {
if (output instanceof NullableAggregateFunction
&& ((NullableAggregateFunction) output).isAlwaysNullable() != alwaysNullable) {
return ((NullableAggregateFunction) output).withAlwaysNullable(alwaysNullable);
} else {
return output;
if (output instanceof AggregateExpression) {
AggregateFunction function = ((AggregateExpression) output).getFunction();
if (function instanceof NullableAggregateFunction
&& ((NullableAggregateFunction) function).isAlwaysNullable() != alwaysNullable) {
AggregateParam param = ((AggregateExpression) output).getAggregateParam();
Expression child = ((AggregateExpression) output).child();
AggregateFunction newFunction = ((NullableAggregateFunction) function)
.withAlwaysNullable(alwaysNullable);
if (function == child) {
// function is also child
child = newFunction;
}
return new AggregateExpression(newFunction, param, child);
}
}
return output;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.AggMode;
Expand All @@ -54,6 +56,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class AggregateStrategiesTest implements MemoPatternMatchSupported {
Expand Down Expand Up @@ -380,6 +383,40 @@ public void distinctWithNormalAggregateFunctionApply4PhaseRule() {
);
}

@Test
public void distinctApply4PhaseRuleNullableChange() {
Slot id = rStudent.getOutput().get(0).toSlot();
List<Expression> groupExpressionList = Lists.newArrayList();
List<NamedExpression> outputExpressionList = Lists.newArrayList(
new Alias(new Count(true, id), "count_id"),
new Alias(new Sum(id), "sum_id"));
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList,
true, Optional.empty(), rStudent);

// select count(distinct id), sum(id) from t;
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyImplementation(fourPhaseAggregateWithDistinct())
.matches(
physicalHashAggregate(
physicalHashAggregate(
physicalHashAggregate(
physicalHashAggregate()
.when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL))
.when(agg -> agg.getGroupByExpressions().get(0).equals(id))
.when(agg -> verifyAlwaysNullableFlag(
agg.getAggregateFunctions(), false)))
.when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
.when(agg -> agg.getGroupByExpressions().get(0).equals(id))
.when(agg -> verifyAlwaysNullableFlag(agg.getAggregateFunctions(),
false)))
.when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> verifyAlwaysNullableFlag(agg.getAggregateFunctions(), true)))
.when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_GLOBAL))
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> verifyAlwaysNullableFlag(agg.getAggregateFunctions(), true)));
}

private Rule twoPhaseAggregateWithoutDistinct() {
return new AggregateStrategies().buildRules()
.stream()
Expand All @@ -400,8 +437,18 @@ private Rule twoPhaseAggregateWithDistinct() {
private Rule fourPhaseAggregateWithDistinct() {
return new AggregateStrategies().buildRules()
.stream()
.filter(rule -> rule.getRuleType() == RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT)
.filter(rule -> rule.getRuleType() == RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT)
.findFirst()
.get();
}

private boolean verifyAlwaysNullableFlag(Set<AggregateFunction> functions, boolean alwaysNullable) {
for (AggregateFunction f : functions) {
if (f instanceof NullableAggregateFunction
&& ((NullableAggregateFunction) f).isAlwaysNullable() != alwaysNullable) {
return false;
}
}
return true;
}
}