From 5662782a339e52032ab3a56e20d7f70095f86c25 Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Sun, 12 May 2024 20:17:10 +0800 Subject: [PATCH 1/5] [opt](nereids)new way to set pre-agg status --- .../SelectMaterializedIndexWithAggregate.java | 345 +++++++++++++++++- .../rewrite/mv/SelectRollupIndexTest.java | 12 +- 2 files changed, 348 insertions(+), 9 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java index b8aae5066862af..324f5ae4ef1bd3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java @@ -39,6 +39,7 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; @@ -76,6 +77,7 @@ import org.apache.doris.nereids.util.Utils; import org.apache.doris.planner.PlanNode; +import com.google.common.base.Preconditions; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -633,8 +635,7 @@ private SelectResult select(LogicalOlapScan scan, Set requiredScanOutput, if ((new CheckContext(scan, selectIndexId)).isBaseIndex()) { PreAggStatus preagg = scan.getPreAggStatus(); if (preagg.isOn()) { - preagg = checkPreAggStatus(scan, scan.getTable().getBaseIndexId(), predicates, aggregateFunctions, - 
groupingExprs); + preagg = checkPreAggStatus(scan, predicates, aggregateFunctions, groupingExprs); } return new SelectResult(preagg, selectIndexId, new ExprRewriteMap()); } @@ -716,6 +717,346 @@ private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, long indexId, S .offOrElse(() -> checkPredicates(ImmutableList.copyOf(predicates), checkContext)); } + private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, Set predicates, + List aggregateFuncs, List groupingExprs) { + MaterializedIndexMeta meta = + olapScan.getTable().getIndexMetaByIndexId(olapScan.getSelectedIndexId()); + if (meta.getKeysType() == KeysType.DUP_KEYS || (meta.getKeysType() == KeysType.UNIQUE_KEYS + && olapScan.getTable().getEnableUniqueKeyMergeOnWrite())) { + return PreAggStatus.on(); + } + Set outputSlots = olapScan.getOutputSet(); + Pair, Set> splittedSlots = splitSlots(outputSlots); + Set keySlots = splittedSlots.first; + Set valueSlots = splittedSlots.second; + Preconditions.checkState(outputSlots.size() == keySlots.size() + valueSlots.size(), + "output slots contains no key or value slots"); + + Set groupInputSlots = ExpressionUtils.getInputSlotSet(groupingExprs); + if (groupInputSlots.retainAll(keySlots)) { + return PreAggStatus + .off(String.format("Grouping expression %s contains non-key column %s", + groupingExprs, groupInputSlots)); + } + + Set predicateInputSlots = ExpressionUtils.getInputSlotSet(predicates); + if (predicateInputSlots.retainAll(keySlots)) { + return PreAggStatus.off(String.format("Predicate %s contains non-key column %s", + predicates, predicateInputSlots)); + } + + return checkAggregateFunctions(aggregateFuncs, keySlots, valueSlots); + } + + private Pair, Set> splitSlots(Set slots) { + Set keySlots = Sets.newHashSetWithExpectedSize(slots.size()); + Set valueSlots = Sets.newHashSetWithExpectedSize(slots.size()); + for (Slot slot : slots) { + if (slot instanceof SlotReference && ((SlotReference) slot).getColumn().isPresent()) { + if (((SlotReference) 
slot).getColumn().get().isKey()) { + keySlots.add((SlotReference) slot); + } else { + valueSlots.add((SlotReference) slot); + } + } + } + return Pair.of(keySlots, valueSlots); + } + + private static class OneValueSlotAggChecker + extends ExpressionVisitor { + public static final OneValueSlotAggChecker INSTANCE = new OneValueSlotAggChecker(); + + public PreAggStatus check(AggregateFunction aggFun, AggregateType aggregateType) { + return aggFun.accept(INSTANCE, aggregateType); + } + + @Override + public PreAggStatus visit(Expression expr, AggregateType aggregateType) { + return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); + } + + @Override + public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, + AggregateType aggregateType) { + return PreAggStatus + .off(String.format("%s is not supported.", aggregateFunction.toSql())); + } + + @Override + public PreAggStatus visitMax(Max max, AggregateType aggregateType) { + if (aggregateType == AggregateType.MAX && !max.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus + .off(String.format("%s is not match agg mode %s or has distinct param", + max.toSql(), aggregateType)); + } + } + + @Override + public PreAggStatus visitMin(Min min, AggregateType aggregateType) { + if (aggregateType == AggregateType.MIN && !min.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus + .off(String.format("%s is not match agg mode %s or has distinct param", + min.toSql(), aggregateType)); + } + } + + @Override + public PreAggStatus visitSum(Sum sum, AggregateType aggregateType) { + if (aggregateType == AggregateType.SUM && !sum.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus + .off(String.format("%s is not match agg mode %s or has distinct param", + sum.toSql(), aggregateType)); + } + } + + @Override + public PreAggStatus visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, + AggregateType aggregateType) { + 
if (aggregateType == AggregateType.BITMAP_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid bitmap_union_count: " + bitmapUnionCount.toSql()); + } + } + + @Override + public PreAggStatus visitBitmapUnion(BitmapUnion bitmapUnion, AggregateType aggregateType) { + if (aggregateType == AggregateType.BITMAP_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid bitmapUnion: " + bitmapUnion.toSql()); + } + } + + @Override + public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg, AggregateType aggregateType) { + if (aggregateType == AggregateType.HLL_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid hllUnionAgg: " + hllUnionAgg.toSql()); + } + } + + @Override + public PreAggStatus visitHllUnion(HllUnion hllUnion, AggregateType aggregateType) { + if (aggregateType == AggregateType.HLL_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid hllUnion: " + hllUnion.toSql()); + } + } + } + + private static class OneKeySlotAggChecker extends ExpressionVisitor { + public static final OneKeySlotAggChecker INSTANCE = new OneKeySlotAggChecker(); + + public PreAggStatus check(AggregateFunction aggFun) { + return aggFun.accept(INSTANCE, null); + } + + @Override + public PreAggStatus visit(Expression expr, Void context) { + return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); + } + + @Override + public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, + Void context) { + return PreAggStatus.off(String.format("Aggregate function %s contains key column %s", + aggregateFunction.toSql(), aggregateFunction.child(0).toSql())); + } + + @Override + public PreAggStatus visitMax(Max max, Void context) { + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitMin(Min min, Void context) { + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitCount(Count count, Void context) { + if 
(count.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off(String.format("%s is not distinct.", count.toSql())); + } + } + } + + private static class KeyAndValueSlotsAggChecker + extends ExpressionVisitor> { + public static final KeyAndValueSlotsAggChecker INSTANCE = new KeyAndValueSlotsAggChecker(); + + public PreAggStatus check(AggregateFunction aggFun, List returnValues) { + return aggFun.accept(INSTANCE, returnValues); + } + + @Override + public PreAggStatus visit(Expression expr, List returnValues) { + return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); + } + + @Override + public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, + List returnValues) { + return PreAggStatus + .off(String.format("%s is not supported.", aggregateFunction.toSql())); + } + + @Override + public PreAggStatus visitSum(Sum sum, List returnValues) { + for (Expression value : returnValues) { + if (!(isAggTypeMatched(value, AggregateType.SUM) || value.isZeroLiteral() + || value.isNullLiteral())) { + return PreAggStatus.off(String.format("%s is not supported.", sum.toSql())); + } + } + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitMax(Max max, List returnValues) { + for (Expression value : returnValues) { + if (!(isAggTypeMatched(value, AggregateType.MAX) || isKeySlot(value) + || value.isNullLiteral())) { + return PreAggStatus.off(String.format("%s is not supported.", max.toSql())); + } + } + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitMin(Min min, List returnValues) { + for (Expression value : returnValues) { + if (!(isAggTypeMatched(value, AggregateType.MIN) || isKeySlot(value) + || value.isNullLiteral())) { + return PreAggStatus.off(String.format("%s is not supported.", min.toSql())); + } + } + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitCount(Count count, List returnValues) { + if (count.isDistinct()) { + for (Expression 
value : returnValues) { + if (!(isKeySlot(value) || value.isZeroLiteral() || value.isNullLiteral())) { + return PreAggStatus + .off(String.format("%s is not supported.", count.toSql())); + } + } + return PreAggStatus.on(); + } else { + return PreAggStatus.off(String.format("%s is not supported.", count.toSql())); + } + } + + private boolean isKeySlot(Expression expression) { + return expression instanceof SlotReference + && ((SlotReference) expression).getColumn().isPresent() + && ((SlotReference) expression).getColumn().get().isKey(); + } + + private boolean isAggTypeMatched(Expression expression, AggregateType aggregateType) { + return expression instanceof SlotReference + && ((SlotReference) expression).getColumn().isPresent() + && ((SlotReference) expression).getColumn().get() + .getAggregationType() == aggregateType; + } + } + + private static Expression removeCast(Expression expression) { + while (expression instanceof Cast) { + expression = ((Cast) expression).child(); + } + return expression; + } + + private PreAggStatus checkAggWithKeyAndValueSlots(AggregateFunction aggFunc, + Set keySlots, Set valueSlots) { + Expression child = aggFunc.child(0); + List conditionExps = new ArrayList<>(); + List returnExps = new ArrayList<>(); + + // ignore cast + while (child instanceof Cast) { + if (!((Cast) child).getDataType().isNumericType()) { + return PreAggStatus.off(String.format("%s is not numeric CAST.", child.toSql())); + } + child = child.child(0); + } + // step 1: extract all condition exprs and return exprs + if (child instanceof If) { + conditionExps.add(child.child(0)); + returnExps.add(removeCast(child.child(1))); + returnExps.add(removeCast(child.child(2))); + } else if (child instanceof CaseWhen) { + CaseWhen caseWhen = (CaseWhen) child; + // WHEN THEN + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + conditionExps.add(whenClause.getOperand()); + returnExps.add(removeCast(whenClause.getResult())); + } + // ELSE + 
returnExps.add(removeCast(caseWhen.getDefaultValue().orElse(new NullLiteral()))); + } else { + // currently, only IF and CASE WHEN are supported + returnExps.add(removeCast(child)); + } + + // step 2: check condition expressions + Set inputSlots = ExpressionUtils.getInputSlotSet(conditionExps); + inputSlots.retainAll(valueSlots); + if (!inputSlots.isEmpty()) { + return PreAggStatus + .off(String.format("some columns in condition %s is not key.", conditionExps)); + } + + return KeyAndValueSlotsAggChecker.INSTANCE.check(aggFunc, returnExps); + } + + private PreAggStatus checkAggregateFunctions(List aggregateFuncs, + Set keySlots, Set valueSlots) { + PreAggStatus preAggStatus = PreAggStatus.on(); + for (AggregateFunction aggFunc : aggregateFuncs) { + if (aggFunc.children().size() == 1 && aggFunc.child(0) instanceof Slot) { + Slot aggSlot = (Slot) aggFunc.child(0); + if (aggSlot instanceof SlotReference + && ((SlotReference) aggSlot).getColumn().isPresent()) { + if (((SlotReference) aggSlot).getColumn().get().isKey()) { + preAggStatus = OneKeySlotAggChecker.INSTANCE.check(aggFunc); + } else { + preAggStatus = OneValueSlotAggChecker.INSTANCE.check(aggFunc, + ((SlotReference) aggSlot).getColumn().get().getAggregationType()); + } + } else { + preAggStatus = PreAggStatus.off( + String.format("aggregate function %s use unknown slot %s from scan", + aggFunc, aggSlot)); + } + } else { + Set aggSlots = aggFunc.getInputSlots(); + Pair, Set> splitSlots = splitSlots(aggSlots); + preAggStatus = + checkAggWithKeyAndValueSlots(aggFunc, splitSlots.first, splitSlots.second); + } + if (preAggStatus.isOff()) { + return preAggStatus; + } + } + return preAggStatus; + } + /** * Check pre agg status according to aggregate functions. 
*/ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java index 0686edba64e01e..beb9029e773508 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java @@ -139,8 +139,7 @@ void testTranslate() { public void testTranslateWhenPreAggIsOff() { singleTableTest("select k2, min(v1) from t group by k2", scan -> { Assertions.assertFalse(scan.isPreAggregation()); - Assertions.assertEquals("Aggregate operator don't match, " - + "aggregate function: min(v1), column aggregate type: SUM", + Assertions.assertEquals("min(v1) is not match agg mode SUM or has distinct param", scan.getReasonOfPreAggregation()); }); } @@ -227,8 +226,7 @@ public void testAggregateTypeNotMatch() { .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); - Assertions.assertEquals("Aggregate operator don't match, " - + "aggregate function: min(v1), column aggregate type: SUM", preAgg.getOffReason()); + Assertions.assertEquals("min(v1) is not match agg mode SUM or has distinct param", preAgg.getOffReason()); return true; })); } @@ -242,7 +240,7 @@ public void testInvalidSlotInAggFunction() { .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); - Assertions.assertEquals("do not support compound expression [(v1 + 1)] in SUM.", + Assertions.assertEquals("sum((v1 + 1)) is not supported.", preAgg.getOffReason()); return true; })); @@ -257,7 +255,7 @@ public void testKeyColumnInAggFunction() { .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); - Assertions.assertEquals("Aggregate function sum(k2) 
contains key column k2.", + Assertions.assertEquals("Aggregate function sum(k2) contains key column k2", preAgg.getOffReason()); return true; })); @@ -402,7 +400,7 @@ public void testCountDistinctKeyColumn() { public void testCountDistinctValueColumn() { singleTableTest("select k1, count(distinct v1) from t group by k1", scan -> { Assertions.assertFalse(scan.isPreAggregation()); - Assertions.assertEquals("Count distinct is only valid for key columns, but meet count(DISTINCT v1).", + Assertions.assertEquals("count(DISTINCT v1) is not supported.", scan.getReasonOfPreAggregation()); Assertions.assertEquals("t", scan.getSelectedIndexName()); }); From 1dcd14afafe25c03cd391ed514120919cdf9281a Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Tue, 14 May 2024 16:05:10 +0800 Subject: [PATCH 2/5] refactor --- .../doris/nereids/jobs/executor/Rewriter.java | 4 + .../apache/doris/nereids/rules/RuleType.java | 15 + .../nereids/rules/analysis/BindRelation.java | 2 +- .../rules/rewrite/AdjustPreAggStatus.java | 743 ++++++++++++++++++ .../AbstractSelectMaterializedIndexRule.java | 12 +- .../SelectMaterializedIndexWithAggregate.java | 363 +-------- ...lectMaterializedIndexWithoutAggregate.java | 20 +- .../nereids/trees/plans/PreAggStatus.java | 15 +- .../trees/plans/logical/LogicalOlapScan.java | 14 +- .../rewrite/mv/SelectRollupIndexTest.java | 18 + 10 files changed, 814 insertions(+), 392 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 2bc61b3b6fe1a9..e2095248298145 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -36,6 +36,7 @@ import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit; 
import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType; import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.rules.rewrite.AdjustPreAggStatus; import org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction; import org.apache.doris.nereids.rules.rewrite.BuildAggForUnion; import org.apache.doris.nereids.rules.rewrite.CTEInline; @@ -391,6 +392,9 @@ public class Rewriter extends AbstractBatchJobExecutor { bottomUp(RuleSet.PUSH_DOWN_FILTERS), custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, EliminateUnnecessaryProject::new) ), + topic("adjust preagg status", + topDown(new AdjustPreAggStatus()) + ), topic("topn optimize", topDown(new DeferMaterializeTopNResult()) ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index c1c6f539c6670c..c0aadf1e730eec 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -241,6 +241,21 @@ public enum RuleType { MATERIALIZED_INDEX_PROJECT_SCAN(RuleTypeClass.REWRITE), MATERIALIZED_INDEX_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), MATERIALIZED_INDEX_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_SCAN(RuleTypeClass.REWRITE), + 
PREAGG_STATUS_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), REDUCE_AGGREGATE_CHILD_OUTPUT_ROWS(RuleTypeClass.REWRITE), OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java index 0e6d940891ebce..df3743928a9b96 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java @@ -206,7 +206,7 @@ private LogicalPlan makeOlapScan(TableIf table, UnboundRelation unboundRelation, } PreAggStatus preAggStatus = olapTable.getIndexMetaByIndexId(indexId).getKeysType().equals(KeysType.DUP_KEYS) - ? PreAggStatus.on() + ? PreAggStatus.unset() : PreAggStatus.off("For direct index scan."); scan = new LogicalOlapScan(unboundRelation.getRelationId(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java new file mode 100644 index 00000000000000..c867cb10b4ea64 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java @@ -0,0 +1,743 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.catalog.AggregateType; +import org.apache.doris.catalog.KeysType; +import org.apache.doris.catalog.MaterializedIndexMeta; +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import 
org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.PreAggStatus; +import org.apache.doris.nereids.trees.plans.algebra.Project; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * AdjustPreAggStatus + */ +@Developing +public class AdjustPreAggStatus implements RewriteRuleFactory { + private static Expression removeCast(Expression expression) { + while (expression instanceof Cast) { + expression = ((Cast) expression).child(); + } + return expression; + } + + /////////////////////////////////////////////////////////////////////////// + // All the patterns + /////////////////////////////////////////////////////////////////////////// + @Override + public List buildRules() { + return ImmutableList.of( + // Aggregate(Scan) + logicalAggregate(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)) + .thenApplyNoThrow(ctx -> { + LogicalAggregate agg = ctx.root; + LogicalOlapScan 
scan = agg.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = agg.getGroupByExpressions(); + Set predicates = ImmutableSet.of(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return agg.withChildren(scan.withPreAggStatus(preAggStatus)); + }).toRule(RuleType.PREAGG_STATUS_AGG_SCAN), + + // Aggregate(Filter(Scan)) + logicalAggregate( + logicalFilter(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate> agg = ctx.root; + LogicalFilter filter = agg.child(); + LogicalOlapScan scan = filter.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + agg.getGroupByExpressions(); + Set predicates = filter.getConjuncts(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return agg.withChildren(filter + .withChildren(scan.withPreAggStatus(preAggStatus))); + }).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_SCAN), + + // Aggregate(Project(Scan)) + logicalAggregate(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate> agg = + ctx.root; + LogicalProject project = agg.child(); + LogicalOlapScan scan = project.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, + Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(agg.getGroupByExpressions(), + project.getAliasToProducer()); + Set predicates = ImmutableSet.of(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return agg.withChildren(project + .withChildren(scan.withPreAggStatus(preAggStatus))); + }).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_SCAN), + + // Aggregate(Project(Filter(Scan))) + 
logicalAggregate(logicalProject(logicalFilter( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>> agg = ctx.root; + LogicalProject> project = agg.child(); + LogicalFilter filter = project.child(); + LogicalOlapScan scan = filter.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(agg.getGroupByExpressions(), + project.getAliasToProducer()); + Set predicates = filter.getConjuncts(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return agg.withChildren(project.withChildren(filter + .withChildren(scan.withPreAggStatus(preAggStatus)))); + }).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_FILTER_SCAN), + + // Aggregate(Filter(Project(Scan))) + logicalAggregate(logicalFilter(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>> agg = ctx.root; + LogicalFilter> filter = + agg.child(); + LogicalProject project = filter.child(); + LogicalOlapScan scan = project.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(agg.getGroupByExpressions(), + project.getAliasToProducer()); + Set predicates = ExpressionUtils.replace( + filter.getConjuncts(), project.getAliasToProducer()); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return agg.withChildren(filter.withChildren(project + .withChildren(scan.withPreAggStatus(preAggStatus)))); + }).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_PROJECT_SCAN), + + // Aggregate(Repeat(Scan)) + logicalAggregate( + logicalRepeat(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate> agg = ctx.root; + LogicalRepeat repeat = agg.child(); 
+ LogicalOlapScan scan = repeat.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = nonVirtualGroupByExprs(agg); + Set predicates = ImmutableSet.of(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return agg.withChildren(repeat + .withChildren(scan.withPreAggStatus(preAggStatus))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_SCAN), + + // Aggregate(Repeat(Filter(Scan))) + logicalAggregate(logicalRepeat(logicalFilter( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>> agg = ctx.root; + LogicalRepeat> repeat = agg.child(); + LogicalFilter filter = repeat.child(); + LogicalOlapScan scan = filter.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + nonVirtualGroupByExprs(agg); + Set predicates = filter.getConjuncts(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return agg.withChildren(repeat.withChildren(filter + .withChildren(scan.withPreAggStatus(preAggStatus)))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_FILTER_SCAN), + + // Aggregate(Repeat(Project(Scan))) + logicalAggregate(logicalRepeat(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>> agg = ctx.root; + LogicalRepeat> repeat = agg.child(); + LogicalProject project = repeat.child(); + LogicalOlapScan scan = project.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + ExpressionUtils.replace(nonVirtualGroupByExprs(agg), + project.getAliasToProducer()); + Set predicates = ImmutableSet.of(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return 
agg.withChildren(repeat.withChildren(project + .withChildren(scan.withPreAggStatus(preAggStatus)))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_PROJECT_SCAN), + + // Aggregate(Repeat(Project(Filter(Scan)))) + logicalAggregate(logicalRepeat(logicalProject(logicalFilter( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>>> agg + = ctx.root; + LogicalRepeat>> repeat = agg.child(); + LogicalProject> project = repeat.child(); + LogicalFilter filter = project.child(); + LogicalOlapScan scan = filter.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + ExpressionUtils.replace(nonVirtualGroupByExprs(agg), + project.getAliasToProducer()); + Set predicates = filter.getConjuncts(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return agg.withChildren(repeat + .withChildren(project.withChildren(filter.withChildren( + scan.withPreAggStatus(preAggStatus))))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_PROJECT_FILTER_SCAN), + + // Aggregate(Repeat(Filter(Project(Scan)))) + logicalAggregate(logicalRepeat(logicalFilter(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>>> agg + = ctx.root; + LogicalRepeat>> repeat = agg.child(); + LogicalFilter> filter = repeat.child(); + LogicalProject project = filter.child(); + LogicalOlapScan scan = project.child(); + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(nonVirtualGroupByExprs(agg), + project.getAliasToProducer()); + Set predicates = ExpressionUtils.replace( + filter.getConjuncts(), project.getAliasToProducer()); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return agg.withChildren(repeat + 
.withChildren(filter.withChildren(project.withChildren( + scan.withPreAggStatus(preAggStatus))))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_FILTER_PROJECT_SCAN), + + // Filter(Project(Scan)) + logicalFilter(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))) + .thenApplyNoThrow(ctx -> { + LogicalFilter> filter = ctx.root; + LogicalProject project = filter.child(); + LogicalOlapScan scan = project.child(); + List aggregateFunctions = ImmutableList.of(); + List groupByExpressions = ImmutableList.of(); + Set predicates = ExpressionUtils.replace( + filter.getConjuncts(), project.getAliasToProducer()); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return filter.withChildren(project + .withChildren(scan.withPreAggStatus(preAggStatus))); + }).toRule(RuleType.PREAGG_STATUS_FILTER_PROJECT_SCAN), + + // Filter(Scan) + logicalFilter(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)) + .thenApplyNoThrow(ctx -> { + LogicalFilter filter = ctx.root; + LogicalOlapScan scan = filter.child(); + List aggregateFunctions = ImmutableList.of(); + List groupByExpressions = ImmutableList.of(); + Set predicates = filter.getConjuncts(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return filter.withChildren(scan.withPreAggStatus(preAggStatus)); + }).toRule(RuleType.PREAGG_STATUS_FILTER_SCAN), + + // only scan. 
+ logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet) + .thenApplyNoThrow(ctx -> { + LogicalOlapScan scan = ctx.root; + List aggregateFunctions = ImmutableList.of(); + List groupByExpressions = ImmutableList.of(); + Set predicates = ImmutableSet.of(); + PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + return scan.withPreAggStatus(preAggStatus); + }).toRule(RuleType.PREAGG_STATUS_SCAN)); + } + + /////////////////////////////////////////////////////////////////////////// + // Set pre-aggregation status. + /////////////////////////////////////////////////////////////////////////// + + /** + * Do aggregate function extraction and replace aggregate function's input slots by underlying project. + *

+ * 1. extract aggregate functions in aggregate plan. + * <p>
+ * 2. replace aggregate function's input slot by underlying project expression if project is present. + * <p>
+ * For example: + * <pre>
+     * input arguments:
+     * agg: Aggregate(sum(v) as sum_value)
+     * underlying project: Project(a + b as v)
+     *
+     * output:
+     * sum(a + b)
+     * </pre>
+ */ + private List extractAggFunctionAndReplaceSlot(LogicalAggregate agg, + Optional> project) { + Optional> slotToProducerOpt = + project.map(Project::getAliasToProducer); + return agg.getOutputExpressions().stream() + // extract aggregate functions. + .flatMap(e -> e.>collect(AggregateFunction.class::isInstance) + .stream()) + // replace aggregate function's input slot by its producing expression. + .map(expr -> slotToProducerOpt + .map(slotToExpressions -> (AggregateFunction) ExpressionUtils.replace(expr, + slotToExpressions)) + .orElse(expr)) + .collect(Collectors.toList()); + } + + private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, Set predicates, + List aggregateFuncs, List groupingExprs) { + long selectIndexId = olapScan.getSelectedIndexId(); + MaterializedIndexMeta meta = olapScan.getTable().getIndexMetaByIndexId(selectIndexId); + if (meta.getKeysType() == KeysType.DUP_KEYS || (meta.getKeysType() == KeysType.UNIQUE_KEYS + && olapScan.getTable().getEnableUniqueKeyMergeOnWrite())) { + return PreAggStatus.on(); + } + Set outputSlots = olapScan.getOutputSet(); + Pair, Set> splittedSlots = splitSlots(outputSlots); + Set keySlots = splittedSlots.first; + Set valueSlots = splittedSlots.second; + Preconditions.checkState(outputSlots.size() == keySlots.size() + valueSlots.size(), + "output slots contains no key or value slots"); + + Set groupingExprsInputSlots = ExpressionUtils.getInputSlotSet(groupingExprs); + if (groupingExprsInputSlots.retainAll(keySlots)) { + return PreAggStatus + .off(String.format("Grouping expression %s contains non-key column %s", + groupingExprs, groupingExprsInputSlots)); + } + + Set predicateInputSlots = ExpressionUtils.getInputSlotSet(predicates); + if (predicateInputSlots.retainAll(keySlots)) { + return PreAggStatus.off(String.format("Predicate %s contains non-key column %s", + predicates, predicateInputSlots)); + } + + return checkAggregateFunctions(aggregateFuncs, groupingExprsInputSlots); + } + + private Pair, Set> 
splitSlots(Set slots) { + Set keySlots = Sets.newHashSetWithExpectedSize(slots.size()); + Set valueSlots = Sets.newHashSetWithExpectedSize(slots.size()); + for (Slot slot : slots) { + if (slot instanceof SlotReference && ((SlotReference) slot).getColumn().isPresent()) { + if (((SlotReference) slot).getColumn().get().isKey()) { + keySlots.add((SlotReference) slot); + } else { + valueSlots.add((SlotReference) slot); + } + } + } + return Pair.of(keySlots, valueSlots); + } + + private PreAggStatus checkAggWithKeyAndValueSlots(AggregateFunction aggFunc, + Set keySlots, Set valueSlots) { + Expression child = aggFunc.child(0); + List conditionExps = new ArrayList<>(); + List returnExps = new ArrayList<>(); + + // ignore cast + while (child instanceof Cast) { + if (!((Cast) child).getDataType().isNumericType()) { + return PreAggStatus.off(String.format("%s is not numeric CAST.", child.toSql())); + } + child = child.child(0); + } + // step 1: extract all condition exprs and return exprs + if (child instanceof If) { + conditionExps.add(child.child(0)); + returnExps.add(removeCast(child.child(1))); + returnExps.add(removeCast(child.child(2))); + } else if (child instanceof CaseWhen) { + CaseWhen caseWhen = (CaseWhen) child; + // WHEN THEN + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + conditionExps.add(whenClause.getOperand()); + returnExps.add(removeCast(whenClause.getResult())); + } + // ELSE + returnExps.add(removeCast(caseWhen.getDefaultValue().orElse(new NullLiteral()))); + } else { + // currently, only IF and CASE WHEN are supported + returnExps.add(removeCast(child)); + } + + // step 2: check condition expressions + Set inputSlots = ExpressionUtils.getInputSlotSet(conditionExps); + inputSlots.retainAll(valueSlots); + if (!inputSlots.isEmpty()) { + return PreAggStatus + .off(String.format("some columns in condition %s is not key.", conditionExps)); + } + + return KeyAndValueSlotsAggChecker.INSTANCE.check(aggFunc, returnExps); + } + + private PreAggStatus 
checkAggregateFunctions(List aggregateFuncs, + Set groupingExprsInputSlots) { + PreAggStatus preAggStatus = aggregateFuncs.isEmpty() && groupingExprsInputSlots.isEmpty() + ? PreAggStatus.off("No aggregate on scan.") + : PreAggStatus.on(); + for (AggregateFunction aggFunc : aggregateFuncs) { + if (aggFunc.children().size() == 1 && aggFunc.child(0) instanceof Slot) { + Slot aggSlot = (Slot) aggFunc.child(0); + if (aggSlot instanceof SlotReference + && ((SlotReference) aggSlot).getColumn().isPresent()) { + if (((SlotReference) aggSlot).getColumn().get().isKey()) { + preAggStatus = OneKeySlotAggChecker.INSTANCE.check(aggFunc); + } else { + preAggStatus = OneValueSlotAggChecker.INSTANCE.check(aggFunc, + ((SlotReference) aggSlot).getColumn().get().getAggregationType()); + } + } else { + preAggStatus = PreAggStatus.off( + String.format("aggregate function %s use unknown slot %s from scan", + aggFunc, aggSlot)); + } + } else { + Set aggSlots = aggFunc.getInputSlots(); + Pair, Set> splitSlots = splitSlots(aggSlots); + preAggStatus = + checkAggWithKeyAndValueSlots(aggFunc, splitSlots.first, splitSlots.second); + } + if (preAggStatus.isOff()) { + return preAggStatus; + } + } + return preAggStatus; + } + + private List nonVirtualGroupByExprs(LogicalAggregate agg) { + return agg.getGroupByExpressions().stream() + .filter(expr -> !(expr instanceof VirtualSlotReference)) + .collect(ImmutableList.toImmutableList()); + } + + /** + * eg: select abs(k1)+1 t,sum(abs(k2+1)) from single_slot group by t order by t; + * +--LogicalAggregate[88] ( groupByExpr=[t#4], outputExpr=[t#4, sum(abs((k2#1 + 1))) AS `sum(abs(k2 + 1))`#5]) + * +--LogicalProject[87] ( distinct=false, projects=[(abs(k1#0) + 1) AS `t`#4, k2#1]) + * +--LogicalOlapScan() + * t -> abs(k1#0) + 1 + */ + private Set collectRequireExprWithAggAndProject( + List aggExpressions, Optional> project) { + List projectExpressions = + project.isPresent() ? 
project.get().getProjects() : null; + if (projectExpressions == null) { + return aggExpressions.stream().collect(ImmutableSet.toImmutableSet()); + } + Optional> slotToProducerOpt = + project.map(Project::getAliasToProducer); + Map exprIdToExpression = projectExpressions.stream() + .collect(Collectors.toMap(NamedExpression::getExprId, e -> { + if (e instanceof Alias) { + return ((Alias) e).child(); + } + return e; + })); + return aggExpressions.stream().map(e -> { + if ((e instanceof NamedExpression) + && exprIdToExpression.containsKey(((NamedExpression) e).getExprId())) { + return exprIdToExpression.get(((NamedExpression) e).getExprId()); + } + return e; + }).map(e -> { + return slotToProducerOpt + .map(slotToExpressions -> ExpressionUtils.replace(e, slotToExpressions)) + .orElse(e); + }).collect(ImmutableSet.toImmutableSet()); + } + + private static class OneValueSlotAggChecker + extends ExpressionVisitor { + public static final OneValueSlotAggChecker INSTANCE = new OneValueSlotAggChecker(); + + public PreAggStatus check(AggregateFunction aggFun, AggregateType aggregateType) { + return aggFun.accept(INSTANCE, aggregateType); + } + + @Override + public PreAggStatus visit(Expression expr, AggregateType aggregateType) { + return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); + } + + @Override + public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, + AggregateType aggregateType) { + return PreAggStatus + .off(String.format("%s is not supported.", aggregateFunction.toSql())); + } + + @Override + public PreAggStatus visitMax(Max max, AggregateType aggregateType) { + if (aggregateType == AggregateType.MAX && !max.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus + .off(String.format("%s is not match agg mode %s or has distinct param", + max.toSql(), aggregateType)); + } + } + + @Override + public PreAggStatus visitMin(Min min, AggregateType aggregateType) { + if (aggregateType == 
AggregateType.MIN && !min.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus + .off(String.format("%s is not match agg mode %s or has distinct param", + min.toSql(), aggregateType)); + } + } + + @Override + public PreAggStatus visitSum(Sum sum, AggregateType aggregateType) { + if (aggregateType == AggregateType.SUM && !sum.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus + .off(String.format("%s is not match agg mode %s or has distinct param", + sum.toSql(), aggregateType)); + } + } + + @Override + public PreAggStatus visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, + AggregateType aggregateType) { + if (aggregateType == AggregateType.BITMAP_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid bitmap_union_count: " + bitmapUnionCount.toSql()); + } + } + + @Override + public PreAggStatus visitBitmapUnion(BitmapUnion bitmapUnion, AggregateType aggregateType) { + if (aggregateType == AggregateType.BITMAP_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid bitmapUnion: " + bitmapUnion.toSql()); + } + } + + @Override + public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg, AggregateType aggregateType) { + if (aggregateType == AggregateType.HLL_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid hllUnionAgg: " + hllUnionAgg.toSql()); + } + } + + @Override + public PreAggStatus visitHllUnion(HllUnion hllUnion, AggregateType aggregateType) { + if (aggregateType == AggregateType.HLL_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid hllUnion: " + hllUnion.toSql()); + } + } + } + + private static class OneKeySlotAggChecker extends ExpressionVisitor { + public static final OneKeySlotAggChecker INSTANCE = new OneKeySlotAggChecker(); + + public PreAggStatus check(AggregateFunction aggFun) { + return aggFun.accept(INSTANCE, null); + } + + @Override + public PreAggStatus visit(Expression 
expr, Void context) { + return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); + } + + @Override + public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, + Void context) { + return PreAggStatus.off(String.format("Aggregate function %s contains key column %s", + aggregateFunction.toSql(), aggregateFunction.child(0).toSql())); + } + + @Override + public PreAggStatus visitMax(Max max, Void context) { + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitMin(Min min, Void context) { + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitCount(Count count, Void context) { + if (count.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off(String.format("%s is not distinct.", count.toSql())); + } + } + } + + private static class KeyAndValueSlotsAggChecker + extends ExpressionVisitor> { + public static final KeyAndValueSlotsAggChecker INSTANCE = new KeyAndValueSlotsAggChecker(); + + public PreAggStatus check(AggregateFunction aggFun, List returnValues) { + return aggFun.accept(INSTANCE, returnValues); + } + + @Override + public PreAggStatus visit(Expression expr, List returnValues) { + return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); + } + + @Override + public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, + List returnValues) { + return PreAggStatus + .off(String.format("%s is not supported.", aggregateFunction.toSql())); + } + + @Override + public PreAggStatus visitSum(Sum sum, List returnValues) { + for (Expression value : returnValues) { + if (!(isAggTypeMatched(value, AggregateType.SUM) || value.isZeroLiteral() + || value.isNullLiteral())) { + return PreAggStatus.off(String.format("%s is not supported.", sum.toSql())); + } + } + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitMax(Max max, List returnValues) { + for (Expression value : returnValues) { + if 
(!(isAggTypeMatched(value, AggregateType.MAX) || isKeySlot(value) + || value.isNullLiteral())) { + return PreAggStatus.off(String.format("%s is not supported.", max.toSql())); + } + } + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitMin(Min min, List returnValues) { + for (Expression value : returnValues) { + if (!(isAggTypeMatched(value, AggregateType.MIN) || isKeySlot(value) + || value.isNullLiteral())) { + return PreAggStatus.off(String.format("%s is not supported.", min.toSql())); + } + } + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitCount(Count count, List returnValues) { + if (count.isDistinct()) { + for (Expression value : returnValues) { + if (!(isKeySlot(value) || value.isZeroLiteral() || value.isNullLiteral())) { + return PreAggStatus + .off(String.format("%s is not supported.", count.toSql())); + } + } + return PreAggStatus.on(); + } else { + return PreAggStatus.off(String.format("%s is not supported.", count.toSql())); + } + } + + private boolean isKeySlot(Expression expression) { + return expression instanceof SlotReference + && ((SlotReference) expression).getColumn().isPresent() + && ((SlotReference) expression).getColumn().get().isKey(); + } + + private boolean isAggTypeMatched(Expression expression, AggregateType aggregateType) { + return expression instanceof SlotReference + && ((SlotReference) expression).getColumn().isPresent() + && ((SlotReference) expression).getColumn().get() + .getAggregationType() == aggregateType; + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java index b5773a7571d24e..1124c141416f3f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java @@ -88,17 +88,7 @@ protected boolean shouldSelectIndexWithAgg(LogicalOlapScan scan) { case AGG_KEYS: case UNIQUE_KEYS: case DUP_KEYS: - // SelectMaterializedIndexWithAggregate(R1) run before SelectMaterializedIndexWithoutAggregate(R2) - // if R1 selects baseIndex and preAggStatus is off - // we should give a chance to R2 to check if some prefix-index can be selected - // so if R1 selects baseIndex and preAggStatus is off, we keep scan's index unselected in order to - // let R2 to get a chance to do its work - // at last, after R1, the scan may be the 4 status - // 1. preAggStatus is ON and baseIndex is selected, it means select baseIndex is correct. - // 2. preAggStatus is ON and some other Index is selected, this is correct, too. - // 3. preAggStatus is OFF, no index is selected, it means R2 could get a chance to run - // so we check the preAggStatus and if some index is selected to make sure R1 can be run only once - return scan.getPreAggStatus().isOn() && !scan.isIndexSelected(); + return !scan.isIndexSelected(); default: return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java index 324f5ae4ef1bd3..b221637f18794a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java @@ -39,7 +39,6 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren; -import org.apache.doris.nereids.trees.expressions.SlotReference; import 
org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; @@ -77,7 +76,6 @@ import org.apache.doris.nereids.util.Utils; import org.apache.doris.planner.PlanNode; -import com.google.common.base.Preconditions; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -213,7 +211,7 @@ public List buildRules() { result.exprRewriteMap.projectExprMap); LogicalProject newProject = new LogicalProject<>( generateNewOutputsWithMvOutputs(mvPlan, newProjectList), - scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId)); + scan.withMaterializedIndexSelected(result.indexId)); return new LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext), new ReplaceExpressions(slotContext) .replace( @@ -261,9 +259,6 @@ public List buildRules() { filter.getExpressions(), project.getExpressions() )) ); - if (mvPlanWithoutAgg.getSelectedIndexId() == result.indexId) { - mvPlanWithoutAgg = mvPlanWithoutAgg.withPreAggStatus(result.preAggStatus); - } SlotContext slotContextWithoutAgg = generateBaseScanExprToMvExpr(mvPlanWithoutAgg); return agg.withChildren(new LogicalProject( @@ -537,7 +532,7 @@ public List buildRules() { result.exprRewriteMap.projectExprMap); LogicalProject newProject = new LogicalProject<>( generateNewOutputsWithMvOutputs(mvPlan, newProjectList), - scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId)); + scan.withMaterializedIndexSelected(result.indexId)); return new LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext), new ReplaceExpressions(slotContext).replace(new LogicalAggregate<>( @@ -554,16 +549,7 @@ public List buildRules() { } private static LogicalOlapScan createLogicalOlapScan(LogicalOlapScan scan, SelectResult result) { - LogicalOlapScan mvPlan; - if 
(result.preAggStatus.isOff()) { - // we only set preAggStatus and make index unselected to let SelectMaterializedIndexWithoutAggregate - // have a chance to run and select proper index - mvPlan = scan.withPreAggStatus(result.preAggStatus); - } else { - mvPlan = - scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId); - } - return mvPlan; + return scan.withMaterializedIndexSelected(result.indexId); } /////////////////////////////////////////////////////////////////////////// @@ -635,7 +621,8 @@ private SelectResult select(LogicalOlapScan scan, Set requiredScanOutput, if ((new CheckContext(scan, selectIndexId)).isBaseIndex()) { PreAggStatus preagg = scan.getPreAggStatus(); if (preagg.isOn()) { - preagg = checkPreAggStatus(scan, predicates, aggregateFunctions, groupingExprs); + preagg = checkPreAggStatus(scan, scan.getTable().getBaseIndexId(), predicates, aggregateFunctions, + groupingExprs); } return new SelectResult(preagg, selectIndexId, new ExprRewriteMap()); } @@ -717,346 +704,6 @@ private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, long indexId, S .offOrElse(() -> checkPredicates(ImmutableList.copyOf(predicates), checkContext)); } - private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, Set predicates, - List aggregateFuncs, List groupingExprs) { - MaterializedIndexMeta meta = - olapScan.getTable().getIndexMetaByIndexId(olapScan.getSelectedIndexId()); - if (meta.getKeysType() == KeysType.DUP_KEYS || (meta.getKeysType() == KeysType.UNIQUE_KEYS - && olapScan.getTable().getEnableUniqueKeyMergeOnWrite())) { - return PreAggStatus.on(); - } - Set outputSlots = olapScan.getOutputSet(); - Pair, Set> splittedSlots = splitSlots(outputSlots); - Set keySlots = splittedSlots.first; - Set valueSlots = splittedSlots.second; - Preconditions.checkState(outputSlots.size() == keySlots.size() + valueSlots.size(), - "output slots contains no key or value slots"); - - Set groupInputSlots = ExpressionUtils.getInputSlotSet(groupingExprs); - if 
(groupInputSlots.retainAll(keySlots)) { - return PreAggStatus - .off(String.format("Grouping expression %s contains non-key column %s", - groupingExprs, groupInputSlots)); - } - - Set predicateInputSlots = ExpressionUtils.getInputSlotSet(predicates); - if (predicateInputSlots.retainAll(keySlots)) { - return PreAggStatus.off(String.format("Predicate %s contains non-key column %s", - predicates, predicateInputSlots)); - } - - return checkAggregateFunctions(aggregateFuncs, keySlots, valueSlots); - } - - private Pair, Set> splitSlots(Set slots) { - Set keySlots = Sets.newHashSetWithExpectedSize(slots.size()); - Set valueSlots = Sets.newHashSetWithExpectedSize(slots.size()); - for (Slot slot : slots) { - if (slot instanceof SlotReference && ((SlotReference) slot).getColumn().isPresent()) { - if (((SlotReference) slot).getColumn().get().isKey()) { - keySlots.add((SlotReference) slot); - } else { - valueSlots.add((SlotReference) slot); - } - } - } - return Pair.of(keySlots, valueSlots); - } - - private static class OneValueSlotAggChecker - extends ExpressionVisitor { - public static final OneValueSlotAggChecker INSTANCE = new OneValueSlotAggChecker(); - - public PreAggStatus check(AggregateFunction aggFun, AggregateType aggregateType) { - return aggFun.accept(INSTANCE, aggregateType); - } - - @Override - public PreAggStatus visit(Expression expr, AggregateType aggregateType) { - return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); - } - - @Override - public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, - AggregateType aggregateType) { - return PreAggStatus - .off(String.format("%s is not supported.", aggregateFunction.toSql())); - } - - @Override - public PreAggStatus visitMax(Max max, AggregateType aggregateType) { - if (aggregateType == AggregateType.MAX && !max.isDistinct()) { - return PreAggStatus.on(); - } else { - return PreAggStatus - .off(String.format("%s is not match agg mode %s or has distinct param", 
- max.toSql(), aggregateType)); - } - } - - @Override - public PreAggStatus visitMin(Min min, AggregateType aggregateType) { - if (aggregateType == AggregateType.MIN && !min.isDistinct()) { - return PreAggStatus.on(); - } else { - return PreAggStatus - .off(String.format("%s is not match agg mode %s or has distinct param", - min.toSql(), aggregateType)); - } - } - - @Override - public PreAggStatus visitSum(Sum sum, AggregateType aggregateType) { - if (aggregateType == AggregateType.SUM && !sum.isDistinct()) { - return PreAggStatus.on(); - } else { - return PreAggStatus - .off(String.format("%s is not match agg mode %s or has distinct param", - sum.toSql(), aggregateType)); - } - } - - @Override - public PreAggStatus visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, - AggregateType aggregateType) { - if (aggregateType == AggregateType.BITMAP_UNION) { - return PreAggStatus.on(); - } else { - return PreAggStatus.off("invalid bitmap_union_count: " + bitmapUnionCount.toSql()); - } - } - - @Override - public PreAggStatus visitBitmapUnion(BitmapUnion bitmapUnion, AggregateType aggregateType) { - if (aggregateType == AggregateType.BITMAP_UNION) { - return PreAggStatus.on(); - } else { - return PreAggStatus.off("invalid bitmapUnion: " + bitmapUnion.toSql()); - } - } - - @Override - public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg, AggregateType aggregateType) { - if (aggregateType == AggregateType.HLL_UNION) { - return PreAggStatus.on(); - } else { - return PreAggStatus.off("invalid hllUnionAgg: " + hllUnionAgg.toSql()); - } - } - - @Override - public PreAggStatus visitHllUnion(HllUnion hllUnion, AggregateType aggregateType) { - if (aggregateType == AggregateType.HLL_UNION) { - return PreAggStatus.on(); - } else { - return PreAggStatus.off("invalid hllUnion: " + hllUnion.toSql()); - } - } - } - - private static class OneKeySlotAggChecker extends ExpressionVisitor { - public static final OneKeySlotAggChecker INSTANCE = new OneKeySlotAggChecker(); - - 
public PreAggStatus check(AggregateFunction aggFun) { - return aggFun.accept(INSTANCE, null); - } - - @Override - public PreAggStatus visit(Expression expr, Void context) { - return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); - } - - @Override - public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, - Void context) { - return PreAggStatus.off(String.format("Aggregate function %s contains key column %s", - aggregateFunction.toSql(), aggregateFunction.child(0).toSql())); - } - - @Override - public PreAggStatus visitMax(Max max, Void context) { - return PreAggStatus.on(); - } - - @Override - public PreAggStatus visitMin(Min min, Void context) { - return PreAggStatus.on(); - } - - @Override - public PreAggStatus visitCount(Count count, Void context) { - if (count.isDistinct()) { - return PreAggStatus.on(); - } else { - return PreAggStatus.off(String.format("%s is not distinct.", count.toSql())); - } - } - } - - private static class KeyAndValueSlotsAggChecker - extends ExpressionVisitor> { - public static final KeyAndValueSlotsAggChecker INSTANCE = new KeyAndValueSlotsAggChecker(); - - public PreAggStatus check(AggregateFunction aggFun, List returnValues) { - return aggFun.accept(INSTANCE, returnValues); - } - - @Override - public PreAggStatus visit(Expression expr, List returnValues) { - return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); - } - - @Override - public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, - List returnValues) { - return PreAggStatus - .off(String.format("%s is not supported.", aggregateFunction.toSql())); - } - - @Override - public PreAggStatus visitSum(Sum sum, List returnValues) { - for (Expression value : returnValues) { - if (!(isAggTypeMatched(value, AggregateType.SUM) || value.isZeroLiteral() - || value.isNullLiteral())) { - return PreAggStatus.off(String.format("%s is not supported.", sum.toSql())); - } - } - return 
PreAggStatus.on(); - } - - @Override - public PreAggStatus visitMax(Max max, List returnValues) { - for (Expression value : returnValues) { - if (!(isAggTypeMatched(value, AggregateType.MAX) || isKeySlot(value) - || value.isNullLiteral())) { - return PreAggStatus.off(String.format("%s is not supported.", max.toSql())); - } - } - return PreAggStatus.on(); - } - - @Override - public PreAggStatus visitMin(Min min, List returnValues) { - for (Expression value : returnValues) { - if (!(isAggTypeMatched(value, AggregateType.MIN) || isKeySlot(value) - || value.isNullLiteral())) { - return PreAggStatus.off(String.format("%s is not supported.", min.toSql())); - } - } - return PreAggStatus.on(); - } - - @Override - public PreAggStatus visitCount(Count count, List returnValues) { - if (count.isDistinct()) { - for (Expression value : returnValues) { - if (!(isKeySlot(value) || value.isZeroLiteral() || value.isNullLiteral())) { - return PreAggStatus - .off(String.format("%s is not supported.", count.toSql())); - } - } - return PreAggStatus.on(); - } else { - return PreAggStatus.off(String.format("%s is not supported.", count.toSql())); - } - } - - private boolean isKeySlot(Expression expression) { - return expression instanceof SlotReference - && ((SlotReference) expression).getColumn().isPresent() - && ((SlotReference) expression).getColumn().get().isKey(); - } - - private boolean isAggTypeMatched(Expression expression, AggregateType aggregateType) { - return expression instanceof SlotReference - && ((SlotReference) expression).getColumn().isPresent() - && ((SlotReference) expression).getColumn().get() - .getAggregationType() == aggregateType; - } - } - - private static Expression removeCast(Expression expression) { - while (expression instanceof Cast) { - expression = ((Cast) expression).child(); - } - return expression; - } - - private PreAggStatus checkAggWithKeyAndValueSlots(AggregateFunction aggFunc, - Set keySlots, Set valueSlots) { - Expression child = aggFunc.child(0); 
- List conditionExps = new ArrayList<>(); - List returnExps = new ArrayList<>(); - - // ignore cast - while (child instanceof Cast) { - if (!((Cast) child).getDataType().isNumericType()) { - return PreAggStatus.off(String.format("%s is not numeric CAST.", child.toSql())); - } - child = child.child(0); - } - // step 1: extract all condition exprs and return exprs - if (child instanceof If) { - conditionExps.add(child.child(0)); - returnExps.add(removeCast(child.child(1))); - returnExps.add(removeCast(child.child(2))); - } else if (child instanceof CaseWhen) { - CaseWhen caseWhen = (CaseWhen) child; - // WHEN THEN - for (WhenClause whenClause : caseWhen.getWhenClauses()) { - conditionExps.add(whenClause.getOperand()); - returnExps.add(removeCast(whenClause.getResult())); - } - // ELSE - returnExps.add(removeCast(caseWhen.getDefaultValue().orElse(new NullLiteral()))); - } else { - // currently, only IF and CASE WHEN are supported - returnExps.add(removeCast(child)); - } - - // step 2: check condition expressions - Set inputSlots = ExpressionUtils.getInputSlotSet(conditionExps); - inputSlots.retainAll(valueSlots); - if (!inputSlots.isEmpty()) { - return PreAggStatus - .off(String.format("some columns in condition %s is not key.", conditionExps)); - } - - return KeyAndValueSlotsAggChecker.INSTANCE.check(aggFunc, returnExps); - } - - private PreAggStatus checkAggregateFunctions(List aggregateFuncs, - Set keySlots, Set valueSlots) { - PreAggStatus preAggStatus = PreAggStatus.on(); - for (AggregateFunction aggFunc : aggregateFuncs) { - if (aggFunc.children().size() == 1 && aggFunc.child(0) instanceof Slot) { - Slot aggSlot = (Slot) aggFunc.child(0); - if (aggSlot instanceof SlotReference - && ((SlotReference) aggSlot).getColumn().isPresent()) { - if (((SlotReference) aggSlot).getColumn().get().isKey()) { - preAggStatus = OneKeySlotAggChecker.INSTANCE.check(aggFunc); - } else { - preAggStatus = OneValueSlotAggChecker.INSTANCE.check(aggFunc, - ((SlotReference) 
aggSlot).getColumn().get().getAggregationType()); - } - } else { - preAggStatus = PreAggStatus.off( - String.format("aggregate function %s use unknown slot %s from scan", - aggFunc, aggSlot)); - } - } else { - Set aggSlots = aggFunc.getInputSlots(); - Pair, Set> splitSlots = splitSlots(aggSlots); - preAggStatus = - checkAggWithKeyAndValueSlots(aggFunc, splitSlots.first, splitSlots.second); - } - if (preAggStatus.isOff()) { - return preAggStatus; - } - } - return preAggStatus; - } - /** * Check pre agg status according to aggregate functions. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java index acffdc3b258052..5e4e1ce44c92dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.plans.PreAggStatus; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -185,7 +184,7 @@ public static LogicalOlapScan select( break; case DUP_KEYS: if (table.getIndexIdToMeta().size() == 1) { - return scan.withMaterializedIndexSelected(PreAggStatus.on(), baseIndexId); + return scan.withMaterializedIndexSelected(baseIndexId); } break; default: @@ -210,19 +209,10 @@ public static LogicalOlapScan select( // this is fail-safe for select mv // select baseIndex if bestIndex's slots' data types are different from baseIndex bestIndex 
= isSameDataType(scan, bestIndex, requiredSlots.get()) ? bestIndex : baseIndexId; - return scan.withMaterializedIndexSelected(PreAggStatus.on(), bestIndex); + return scan.withMaterializedIndexSelected(bestIndex); } else { - final PreAggStatus preAggStatus; - if (preAggEnabledByHint(scan)) { - // PreAggStatus could be enabled by pre-aggregation hint for agg-keys and unique-keys. - preAggStatus = PreAggStatus.on(); - } else { - // if PreAggStatus is OFF, we use the message from SelectMaterializedIndexWithAggregate - preAggStatus = scan.getPreAggStatus().isOff() ? scan.getPreAggStatus() - : PreAggStatus.off("No aggregate on scan."); - } if (table.getIndexIdToMeta().size() == 1) { - return scan.withMaterializedIndexSelected(preAggStatus, baseIndexId); + return scan.withMaterializedIndexSelected(baseIndexId); } int baseIndexKeySize = table.getKeyColumnsByIndexId(table.getBaseIndexId()).size(); // No aggregate on scan. @@ -235,13 +225,13 @@ public static LogicalOlapScan select( if (candidates.size() == 1) { // `candidates` only have base index. - return scan.withMaterializedIndexSelected(preAggStatus, baseIndexId); + return scan.withMaterializedIndexSelected(baseIndexId); } else { long bestIndex = selectBestIndex(candidates, scan, predicatesSupplier.get(), requiredExpr.get()); // this is fail-safe for select mv // select baseIndex if bestIndex's slots' data types are different from baseIndex bestIndex = isSameDataType(scan, bestIndex, requiredSlots.get()) ? 
bestIndex : baseIndexId; - return scan.withMaterializedIndexSelected(preAggStatus, bestIndex); + return scan.withMaterializedIndexSelected(bestIndex); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PreAggStatus.java index 7affac49b2bc09..8ba99c2c07f0eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PreAggStatus.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PreAggStatus.java @@ -26,10 +26,11 @@ public class PreAggStatus { private enum Status { - ON, OFF + ON, OFF, UNSET } private static final PreAggStatus PRE_AGG_ON = new PreAggStatus(Status.ON, ""); + private static final PreAggStatus PRE_AGG_UNSET = new PreAggStatus(Status.UNSET, ""); private final Status status; private final String offReason; @@ -46,6 +47,10 @@ public boolean isOff() { return status == Status.OFF; } + public boolean isUnset() { + return status == Status.UNSET; + } + public String getOffReason() { return offReason; } @@ -58,6 +63,10 @@ public PreAggStatus offOrElse(Supplier supplier) { } } + public static PreAggStatus unset() { + return PRE_AGG_UNSET; + } + public static PreAggStatus on() { return PRE_AGG_ON; } @@ -70,8 +79,10 @@ public static PreAggStatus off(String reason) { public String toString() { if (status == Status.ON) { return "ON"; - } else { + } else if (status == Status.OFF) { return "OFF, " + offReason; + } else { + return "UNSET"; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java index d0d91f1cf8dafb..714f540524f1a8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java @@ -126,7 +126,7 @@ public 
LogicalOlapScan(RelationId id, OlapTable table, List qualifier) { this(id, table, qualifier, Optional.empty(), Optional.empty(), table.getPartitionIds(), false, ImmutableList.of(), - -1, false, PreAggStatus.on(), ImmutableList.of(), ImmutableList.of(), + -1, false, PreAggStatus.unset(), ImmutableList.of(), ImmutableList.of(), Maps.newHashMap(), Optional.empty(), false, false); } @@ -134,7 +134,7 @@ public LogicalOlapScan(RelationId id, OlapTable table, List qualifier, L List hints, Optional tableSample) { this(id, table, qualifier, Optional.empty(), Optional.empty(), table.getPartitionIds(), false, tabletIds, - -1, false, PreAggStatus.on(), ImmutableList.of(), hints, Maps.newHashMap(), + -1, false, PreAggStatus.unset(), ImmutableList.of(), hints, Maps.newHashMap(), tableSample, false, false); } @@ -143,7 +143,7 @@ public LogicalOlapScan(RelationId id, OlapTable table, List qualifier, L this(id, table, qualifier, Optional.empty(), Optional.empty(), // must use specifiedPartitions here for prune partition by sql like 'select * from t partition p1' specifiedPartitions, false, tabletIds, - -1, false, PreAggStatus.on(), specifiedPartitions, hints, Maps.newHashMap(), + -1, false, PreAggStatus.unset(), specifiedPartitions, hints, Maps.newHashMap(), tableSample, false, false); } @@ -275,11 +275,11 @@ public LogicalOlapScan withSelectedPartitionIds(List selectedPartitionIds) hints, cacheSlotWithSlotName, tableSample, directMvScan, projectPulledUp); } - public LogicalOlapScan withMaterializedIndexSelected(PreAggStatus preAgg, long indexId) { + public LogicalOlapScan withMaterializedIndexSelected(long indexId) { return new LogicalOlapScan(relationId, (Table) table, qualifier, Optional.empty(), Optional.of(getLogicalProperties()), selectedPartitionIds, partitionPruned, selectedTabletIds, - indexId, true, preAgg, manuallySpecifiedPartitions, hints, cacheSlotWithSlotName, + indexId, true, PreAggStatus.unset(), manuallySpecifiedPartitions, hints, cacheSlotWithSlotName, 
tableSample, directMvScan, projectPulledUp); } @@ -432,6 +432,10 @@ public boolean isDirectMvScan() { return directMvScan; } + public boolean isPreAggStatusUnSet() { + return preAggStatus.isUnset(); + } + private List createSlotsVectorized(List columns) { List qualified = qualified(); Object[] slots = new Object[columns.size()]; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java index beb9029e773508..45552bfc2fae6e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java @@ -19,6 +19,7 @@ import org.apache.doris.common.FeConstants; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; +import org.apache.doris.nereids.rules.rewrite.AdjustPreAggStatus; import org.apache.doris.nereids.rules.rewrite.MergeProjects; import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject; import org.apache.doris.nereids.trees.plans.PreAggStatus; @@ -110,6 +111,7 @@ public void testMatchingBase() { PlanChecker.from(connectContext) .analyze(" select k1, sum(v1) from t group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("t", scan.getSelectedMaterializedIndexName().get()); @@ -122,6 +124,7 @@ void testAggFilterScan() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k3=0 group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r2", 
scan.getSelectedMaterializedIndexName().get()); @@ -149,6 +152,7 @@ public void testWithEqualFilter() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k3=0 group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); @@ -161,6 +165,7 @@ public void testWithNonEqualFilter() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k3>0 group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); @@ -173,6 +178,7 @@ public void testWithFilter() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k2>3 group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r1", scan.getSelectedMaterializedIndexName().get()); @@ -192,6 +198,7 @@ public void testWithFilterAndProject() { .applyBottomUp(new MergeProjects()) .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); @@ -209,6 +216,7 @@ public void testNoAggregate() { .analyze("select k1, v1 from t") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) 
.matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); @@ -223,6 +231,7 @@ public void testAggregateTypeNotMatch() { .analyze("select k1, min(v1) from t group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); @@ -237,6 +246,7 @@ public void testInvalidSlotInAggFunction() { .analyze("select k1, sum(v1 + 1) from t group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); @@ -252,6 +262,7 @@ public void testKeyColumnInAggFunction() { .analyze("select k1, sum(k2) from t group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); @@ -267,6 +278,7 @@ public void testMaxCanUseKeyColumn() { .analyze("select k2, max(k3) from t group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -281,6 +293,7 @@ public void testMinCanUseKeyColumn() { .analyze("select k2, min(k3) from t group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new 
AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -295,6 +308,7 @@ public void testMinMaxCanUseKeyColumnWithBaseTable() { .analyze("select k1, min(k2), max(k2) from t group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -309,6 +323,8 @@ public void testFilterAggWithBaseTable() { .analyze("select k1 from t where k1 = 0 group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new MergeProjects()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -323,6 +339,7 @@ public void testDuplicatePreAggOn() { .analyze("select k1, sum(k1) from duplicate_tbl group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -336,6 +353,7 @@ public void testDuplicatePreAggOnEvenWithoutAggregate() { .analyze("select k1, v1 from duplicate_tbl") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); From 65060bd7773e1889c44c4a27832a5a74fd4cb84e Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Tue, 14 May 2024 18:50:55 +0800 Subject: [PATCH 3/5] fix 
fe ut --- .../org/apache/doris/nereids/trees/plans/PlanToStringTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java index 44cb6c296af30d..0a6eb7e0c592ef 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java @@ -84,7 +84,7 @@ public void testLogicalOlapScan() { Assertions.assertTrue( plan.toString().matches("LogicalOlapScan \\( qualified=db\\.table, " + "indexName=, " - + "selectedIndexId=-1, preAgg=ON \\)")); + + "selectedIndexId=-1, preAgg=UNSET \\)")); } @Test From 550e07de3b257ed8fdd6174a46d2e1d79745e8e8 Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Thu, 16 May 2024 18:10:45 +0800 Subject: [PATCH 4/5] remove unused code --- .../rules/rewrite/AdjustPreAggStatus.java | 53 +++---------------- 1 file changed, 7 insertions(+), 46 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java index c867cb10b4ea64..d6ddbb5bcc6401 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java @@ -24,12 +24,9 @@ import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; -import 
org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; @@ -73,13 +70,6 @@ */ @Developing public class AdjustPreAggStatus implements RewriteRuleFactory { - private static Expression removeCast(Expression expression) { - while (expression instanceof Cast) { - expression = ((Cast) expression).child(); - } - return expression; - } - /////////////////////////////////////////////////////////////////////////// // All the patterns /////////////////////////////////////////////////////////////////////////// @@ -410,6 +400,13 @@ private Pair, Set> splitSlots(Set slots) return Pair.of(keySlots, valueSlots); } + private static Expression removeCast(Expression expression) { + while (expression instanceof Cast) { + expression = ((Cast) expression).child(); + } + return expression; + } + private PreAggStatus checkAggWithKeyAndValueSlots(AggregateFunction aggFunc, Set keySlots, Set valueSlots) { Expression child = aggFunc.child(0); @@ -493,42 +490,6 @@ private List nonVirtualGroupByExprs(LogicalAggregate .collect(ImmutableList.toImmutableList()); } - /** - * eg: select abs(k1)+1 t,sum(abs(k2+1)) from single_slot group by t order by t; - * +--LogicalAggregate[88] ( groupByExpr=[t#4], outputExpr=[t#4, sum(abs((k2#1 + 1))) AS `sum(abs(k2 + 1))`#5]) - * +--LogicalProject[87] ( distinct=false, projects=[(abs(k1#0) + 1) AS `t`#4, k2#1]) - * +--LogicalOlapScan() - * t -> abs(k1#0) + 1 - */ - private Set collectRequireExprWithAggAndProject( - List aggExpressions, Optional> project) { - List projectExpressions = - project.isPresent() ? 
project.get().getProjects() : null; - if (projectExpressions == null) { - return aggExpressions.stream().collect(ImmutableSet.toImmutableSet()); - } - Optional> slotToProducerOpt = - project.map(Project::getAliasToProducer); - Map exprIdToExpression = projectExpressions.stream() - .collect(Collectors.toMap(NamedExpression::getExprId, e -> { - if (e instanceof Alias) { - return ((Alias) e).child(); - } - return e; - })); - return aggExpressions.stream().map(e -> { - if ((e instanceof NamedExpression) - && exprIdToExpression.containsKey(((NamedExpression) e).getExprId())) { - return exprIdToExpression.get(((NamedExpression) e).getExprId()); - } - return e; - }).map(e -> { - return slotToProducerOpt - .map(slotToExpressions -> ExpressionUtils.replace(e, slotToExpressions)) - .orElse(e); - }).collect(ImmutableSet.toImmutableSet()); - } - private static class OneValueSlotAggChecker extends ExpressionVisitor { public static final OneValueSlotAggChecker INSTANCE = new OneValueSlotAggChecker(); From 2a4e324fe33f3696f27e91f24b613eb2deaf418a Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Fri, 17 May 2024 14:01:08 +0800 Subject: [PATCH 5/5] update code --- .../rules/rewrite/AdjustPreAggStatus.java | 234 +++++++++++------- 1 file changed, 139 insertions(+), 95 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java index d6ddbb5bcc6401..a0c0b56dd71c99 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java @@ -81,12 +81,15 @@ public List buildRules() { .thenApplyNoThrow(ctx -> { LogicalAggregate agg = ctx.root; LogicalOlapScan scan = agg.child(); - List aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, Optional.empty()); - List groupByExpressions = 
agg.getGroupByExpressions(); - Set predicates = ImmutableSet.of(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = agg.getGroupByExpressions(); + Set predicates = ImmutableSet.of(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return agg.withChildren(scan.withPreAggStatus(preAggStatus)); }).toRule(RuleType.PREAGG_STATUS_AGG_SCAN), @@ -97,13 +100,16 @@ public List buildRules() { LogicalAggregate> agg = ctx.root; LogicalFilter filter = agg.child(); LogicalOlapScan scan = filter.child(); - List aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, Optional.empty()); - List groupByExpressions = - agg.getGroupByExpressions(); - Set predicates = filter.getConjuncts(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + agg.getGroupByExpressions(); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return agg.withChildren(filter .withChildren(scan.withPreAggStatus(preAggStatus))); }).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_SCAN), @@ -116,15 +122,18 @@ public List buildRules() { ctx.root; LogicalProject project = agg.child(); LogicalOlapScan scan = project.child(); - List aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, - Optional.of(project)); - List groupByExpressions = - ExpressionUtils.replace(agg.getGroupByExpressions(), - project.getAliasToProducer()); - Set predicates = 
ImmutableSet.of(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, + Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(agg.getGroupByExpressions(), + project.getAliasToProducer()); + Set predicates = ImmutableSet.of(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return agg.withChildren(project .withChildren(scan.withPreAggStatus(preAggStatus))); }).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_SCAN), @@ -137,14 +146,17 @@ public List buildRules() { LogicalProject> project = agg.child(); LogicalFilter filter = project.child(); LogicalOlapScan scan = filter.child(); - List aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); - List groupByExpressions = - ExpressionUtils.replace(agg.getGroupByExpressions(), - project.getAliasToProducer()); - Set predicates = filter.getConjuncts(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(agg.getGroupByExpressions(), + project.getAliasToProducer()); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return agg.withChildren(project.withChildren(filter .withChildren(scan.withPreAggStatus(preAggStatus)))); }).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_FILTER_SCAN), @@ -158,15 +170,18 @@ public List buildRules() { agg.child(); LogicalProject project = filter.child(); LogicalOlapScan scan = project.child(); - List 
aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); - List groupByExpressions = - ExpressionUtils.replace(agg.getGroupByExpressions(), - project.getAliasToProducer()); - Set predicates = ExpressionUtils.replace( - filter.getConjuncts(), project.getAliasToProducer()); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(agg.getGroupByExpressions(), + project.getAliasToProducer()); + Set predicates = ExpressionUtils.replace( + filter.getConjuncts(), project.getAliasToProducer()); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return agg.withChildren(filter.withChildren(project .withChildren(scan.withPreAggStatus(preAggStatus)))); }).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_PROJECT_SCAN), @@ -178,12 +193,15 @@ public List buildRules() { LogicalAggregate> agg = ctx.root; LogicalRepeat repeat = agg.child(); LogicalOlapScan scan = repeat.child(); - List aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, Optional.empty()); - List groupByExpressions = nonVirtualGroupByExprs(agg); - Set predicates = ImmutableSet.of(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = nonVirtualGroupByExprs(agg); + Set predicates = ImmutableSet.of(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return agg.withChildren(repeat .withChildren(scan.withPreAggStatus(preAggStatus))); 
}).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_SCAN), @@ -196,13 +214,16 @@ public List buildRules() { LogicalRepeat> repeat = agg.child(); LogicalFilter filter = repeat.child(); LogicalOlapScan scan = filter.child(); - List aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, Optional.empty()); - List groupByExpressions = - nonVirtualGroupByExprs(agg); - Set predicates = filter.getConjuncts(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + nonVirtualGroupByExprs(agg); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return agg.withChildren(repeat.withChildren(filter .withChildren(scan.withPreAggStatus(preAggStatus)))); }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_FILTER_SCAN), @@ -215,14 +236,17 @@ public List buildRules() { LogicalRepeat> repeat = agg.child(); LogicalProject project = repeat.child(); LogicalOlapScan scan = project.child(); - List aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, Optional.empty()); - List groupByExpressions = - ExpressionUtils.replace(nonVirtualGroupByExprs(agg), - project.getAliasToProducer()); - Set predicates = ImmutableSet.of(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + ExpressionUtils.replace(nonVirtualGroupByExprs(agg), + project.getAliasToProducer()); + Set predicates = ImmutableSet.of(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, 
groupByExpressions); + } return agg.withChildren(repeat.withChildren(project .withChildren(scan.withPreAggStatus(preAggStatus)))); }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_PROJECT_SCAN), @@ -237,14 +261,17 @@ public List buildRules() { LogicalProject> project = repeat.child(); LogicalFilter filter = project.child(); LogicalOlapScan scan = filter.child(); - List aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, Optional.empty()); - List groupByExpressions = - ExpressionUtils.replace(nonVirtualGroupByExprs(agg), - project.getAliasToProducer()); - Set predicates = filter.getConjuncts(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + ExpressionUtils.replace(nonVirtualGroupByExprs(agg), + project.getAliasToProducer()); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return agg.withChildren(repeat .withChildren(project.withChildren(filter.withChildren( scan.withPreAggStatus(preAggStatus))))); @@ -260,15 +287,18 @@ public List buildRules() { LogicalFilter> filter = repeat.child(); LogicalProject project = filter.child(); LogicalOlapScan scan = project.child(); - List aggregateFunctions = - extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); - List groupByExpressions = - ExpressionUtils.replace(nonVirtualGroupByExprs(agg), - project.getAliasToProducer()); - Set predicates = ExpressionUtils.replace( - filter.getConjuncts(), project.getAliasToProducer()); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + 
extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(nonVirtualGroupByExprs(agg), + project.getAliasToProducer()); + Set predicates = ExpressionUtils.replace( + filter.getConjuncts(), project.getAliasToProducer()); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return agg.withChildren(repeat .withChildren(filter.withChildren(project.withChildren( scan.withPreAggStatus(preAggStatus))))); @@ -281,12 +311,15 @@ public List buildRules() { LogicalFilter> filter = ctx.root; LogicalProject project = filter.child(); LogicalOlapScan scan = project.child(); - List aggregateFunctions = ImmutableList.of(); - List groupByExpressions = ImmutableList.of(); - Set predicates = ExpressionUtils.replace( - filter.getConjuncts(), project.getAliasToProducer()); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = ImmutableList.of(); + List groupByExpressions = ImmutableList.of(); + Set predicates = ExpressionUtils.replace( + filter.getConjuncts(), project.getAliasToProducer()); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return filter.withChildren(project .withChildren(scan.withPreAggStatus(preAggStatus))); }).toRule(RuleType.PREAGG_STATUS_FILTER_PROJECT_SCAN), @@ -296,11 +329,14 @@ public List buildRules() { .thenApplyNoThrow(ctx -> { LogicalFilter filter = ctx.root; LogicalOlapScan scan = filter.child(); - List aggregateFunctions = ImmutableList.of(); - List groupByExpressions = ImmutableList.of(); - Set predicates = filter.getConjuncts(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) 
{ + List aggregateFunctions = ImmutableList.of(); + List groupByExpressions = ImmutableList.of(); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return filter.withChildren(scan.withPreAggStatus(preAggStatus)); }).toRule(RuleType.PREAGG_STATUS_FILTER_SCAN), @@ -308,11 +344,14 @@ public List buildRules() { logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet) .thenApplyNoThrow(ctx -> { LogicalOlapScan scan = ctx.root; - List aggregateFunctions = ImmutableList.of(); - List groupByExpressions = ImmutableList.of(); - Set predicates = ImmutableSet.of(); - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, - aggregateFunctions, groupByExpressions); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = ImmutableList.of(); + List groupByExpressions = ImmutableList.of(); + Set predicates = ImmutableSet.of(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } return scan.withPreAggStatus(preAggStatus); }).toRule(RuleType.PREAGG_STATUS_SCAN)); } @@ -354,14 +393,19 @@ private List extractAggFunctionAndReplaceSlot(LogicalAggregat .collect(Collectors.toList()); } - private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, Set predicates, - List aggregateFuncs, List groupingExprs) { + private PreAggStatus checkKeysType(LogicalOlapScan olapScan) { long selectIndexId = olapScan.getSelectedIndexId(); MaterializedIndexMeta meta = olapScan.getTable().getIndexMetaByIndexId(selectIndexId); if (meta.getKeysType() == KeysType.DUP_KEYS || (meta.getKeysType() == KeysType.UNIQUE_KEYS && olapScan.getTable().getEnableUniqueKeyMergeOnWrite())) { return PreAggStatus.on(); + } else { + return PreAggStatus.unset(); } + } + + private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, Set predicates, + List aggregateFuncs, List groupingExprs) { Set 
outputSlots = olapScan.getOutputSet(); Pair, Set> splittedSlots = splitSlots(outputSlots); Set keySlots = splittedSlots.first;