diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java index 825a9f9ac17aa7..370d445049eaa5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java @@ -30,7 +30,6 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; -import org.apache.doris.nereids.rules.rewrite.mv.AbstractSelectMaterializedIndexRule.ReplaceExpressions; import org.apache.doris.nereids.rules.rewrite.mv.AbstractSelectMaterializedIndexRule.SlotContext; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; @@ -53,10 +52,10 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator; -import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapHash; import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash; import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap; import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmapWithCheck; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; @@ -642,6 +641,19 @@ public List buildRules() { ); } + private static LogicalOlapScan createLogicalOlapScan(LogicalOlapScan scan, SelectResult result) { + LogicalOlapScan mvPlan; + if (result.preAggStatus.isOff()) { + // we only set preAggStatus and make index unselected to let SelectMaterializedIndexWithoutAggregate + // have a chance to run and select proper index + mvPlan = scan.withPreAggStatus(result.preAggStatus); + } else { + mvPlan = + scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId); + } + return mvPlan; + } + /////////////////////////////////////////////////////////////////////////// // Main entrance of select materialized index. /////////////////////////////////////////////////////////////////////////// @@ -745,19 +757,6 @@ public SelectResult(PreAggStatus preAggStatus, long indexId, ExprRewriteMap expr } } - private static LogicalOlapScan createLogicalOlapScan(LogicalOlapScan scan, SelectResult result) { - LogicalOlapScan mvPlan; - if (result.preAggStatus.isOff()) { - // we only set preAggStatus and make index unselected to let SelectMaterializedIndexWithoutAggregate - // have a chance to run and select proper index - mvPlan = scan.withPreAggStatus(result.preAggStatus); - } else { - mvPlan = - scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId); - } - return mvPlan; - } - /** * Do aggregate function extraction and replace aggregate function's input slots by underlying project. *

@@ -971,27 +970,24 @@ public CheckContext(LogicalOlapScan scan, long indexId) { Supplier> supplier = () -> Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER); // map> - Map> baseNameToColumnGroupingByIsKey = - scan.getTable().getSchemaByIndexId(indexId).stream() - .collect( - Collectors.groupingBy(Column::isKey, - Collectors.toMap( - c -> isBaseIndex ? c.getName() - : normalizeName(parseMvColumnToSql( - c.getName())), - Function.identity(), (v1, v2) -> v1, - supplier))); - Map> mvNameToColumnGroupingByIsKey = - scan.getTable().getSchemaByIndexId(indexId).stream() - .collect(Collectors.groupingBy(Column::isKey, - Collectors.toMap( - c -> isBaseIndex ? c.getName() - : normalizeName(parseMvColumnToMvName( - c.getNameWithoutMvPrefix(), - c.isAggregated() ? Optional.of( - c.getAggregationType().name()) - : Optional.empty())), - Function.identity(), (v1, v2) -> v1, supplier))); + Map> baseNameToColumnGroupingByIsKey = scan.getTable() + .getSchemaByIndexId(indexId).stream() + .collect(Collectors.groupingBy(Column::isKey, + Collectors.toMap( + c -> isBaseIndex ? c.getName() + : normalizeName(parseMvColumnToSql(c.getName())), + Function.identity(), (v1, v2) -> v1, supplier))); + Map> mvNameToColumnGroupingByIsKey = scan.getTable() + .getSchemaByIndexId(indexId).stream() + .collect(Collectors.groupingBy(Column::isKey, + Collectors.toMap( + c -> isBaseIndex ? c.getName() + : normalizeName(parseMvColumnToMvName( + c.getNameWithoutMvPrefix(), + c.isAggregated() + ? Optional.of(c.getAggregationType().name()) + : Optional.empty())), + Function.identity(), (v1, v2) -> v1, supplier))); this.keyNameToColumn = mvNameToColumnGroupingByIsKey.getOrDefault(true, Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER)); @@ -1033,11 +1029,13 @@ private PreAggStatus checkGroupingExprs( /** * Predicates should not have value type columns. */ - private PreAggStatus checkPredicates( - List predicates, - CheckContext checkContext) { - return disablePreAggIfContainsAnyValueColumn(predicates, checkContext, - "Predicate %s contains value column %s"); + private PreAggStatus checkPredicates(List predicates, CheckContext checkContext) { + Set indexConjuncts = PlanNode + .splitAndCompoundPredicateToConjuncts(checkContext.getMeta().getWhereClause()).stream() + .map(e -> new NereidsParser().parseExpression(e.toSql()).toSql()).collect(Collectors.toSet()); + return disablePreAggIfContainsAnyValueColumn( + predicates.stream().filter(e -> !indexConjuncts.contains(e.toSql())).collect(Collectors.toList()), + checkContext, "Predicate %s contains value column %s"); } /** @@ -1075,24 +1073,26 @@ private AggRewriteResult rewriteAgg(MaterializedIndex index, // has rewritten agg functions Map slotMap = exprRewriteMap.slotMap; - if (!slotMap.isEmpty()) { - // Note that the slots in the rewritten agg functions shouldn't appear in filters or grouping expressions. - // For example: we have a duplicated-type table t(c1, c2) and a materialized index that has - // a bitmap_union column `mv_bitmap_union_c2` for the column c2. - // The query `select c1, count(distinct c2) from t where c2 > 0 group by c1` can't use the materialized - // index because we have a filter `c2 > 0` for the aggregated column c2. - Set slotsToReplace = slotMap.keySet(); - Set indexConjuncts = PlanNode + // Note that the slots in the rewritten agg functions shouldn't appear in filters or grouping expressions. + // For example: we have a duplicated-type table t(c1, c2) and a materialized index that has + // a bitmap_union column `mv_bitmap_union_c2` for the column c2. + // The query `select c1, count(distinct c2) from t where c2 > 0 group by c1` can't use the materialized + // index because we have a filter `c2 > 0` for the aggregated column c2. + Set slotsToReplace = slotMap.keySet(); + Set indexConjuncts; + try { + indexConjuncts = PlanNode .splitAndCompoundPredicateToConjuncts(context.checkContext.getMeta().getWhereClause()).stream() .map(e -> new NereidsParser().parseExpression(e.toSql()).toSql()).collect(Collectors.toSet()); - if (isInputSlotsContainsNone( - predicates.stream().filter(e -> !indexConjuncts.contains(e.toSql())).collect(Collectors.toList()), - slotsToReplace) && isInputSlotsContainsNone(groupingExprs, slotsToReplace)) { - ImmutableSet newRequiredSlots = requiredScanOutput.stream() - .map(slot -> (Slot) ExpressionUtils.replace(slot, slotMap)) - .collect(ImmutableSet.toImmutableSet()); - return new AggRewriteResult(index, true, newRequiredSlots, exprRewriteMap); - } + } catch (Exception e) { + return new AggRewriteResult(index, false, null, null); + } + if (isInputSlotsContainsNone( + predicates.stream().filter(e -> !indexConjuncts.contains(e.toSql())).collect(Collectors.toList()), + slotsToReplace) && isInputSlotsContainsNone(groupingExprs, slotsToReplace)) { + ImmutableSet newRequiredSlots = requiredScanOutput.stream() + .map(slot -> (Slot) ExpressionUtils.replace(slot, slotMap)).collect(ImmutableSet.toImmutableSet()); + return new AggRewriteResult(index, true, newRequiredSlots, exprRewriteMap); } return new AggRewriteResult(index, false, null, null); @@ -1210,8 +1210,7 @@ public Expression visitCount(Count count, RewriteContext context) { Expression expr = new ToBitmapWithCheck(castIfNeed(count.child(0), BigIntType.INSTANCE)); // count distinct a value column. - if (slotOpt.isPresent() && !context.checkContext.keyNameToColumn.containsKey( - normalizeName(expr.toSql()))) { + if (slotOpt.isPresent()) { String bitmapUnionColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder( AggregateType.BITMAP_UNION, CreateMaterializedViewStmt.mvColumnBuilder(expr.toSql()))); @@ -1232,38 +1231,102 @@ public Expression visitCount(Count count, RewriteContext context) { return bitmapUnionCount; } } - } else if (!count.isDistinct() && count.arity() == 1) { - // count(col) -> sum(mva_SUM__CASE WHEN col IS NULL THEN 0 ELSE 1 END) + } + Expression child = null; + if (!count.isDistinct() && count.arity() == 1) { + // count(col) -> sum(mva_SUM__CASE WHEN col IS NULL THEN 0 ELSE 1 END) Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(count.child(0)); - // count a value column. - if (slotOpt.isPresent() && !context.checkContext.keyNameToColumn.containsKey( - normalizeName(slotOpt.get().toSql()))) { - String countColumn = normalizeName(CreateMaterializedViewStmt - .mvColumnBuilder(AggregateType.SUM, - CreateMaterializedViewStmt.mvColumnBuilder(slotToCaseWhen(slotOpt.get()).toSql()))); - - Column mvColumn = context.checkContext.getColumn(countColumn); - // has bitmap_union_count column - if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) { - Slot countSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index) - .stream() - .filter(s -> countColumn.equalsIgnoreCase(normalizeName(s.getName()))) - .findFirst() - .orElseThrow(() -> new AnalysisException( - "cannot find count slot when select mv")); + if (slotOpt.isPresent()) { + child = slotOpt.get(); + } + } else if (count.arity() == 0) { + // count(*) / count(1) -> sum(mva_SUM__CASE WHEN 1 IS NULL THEN 0 ELSE 1 END) + child = new TinyIntLiteral((byte) 1); + } + + if (child != null) { + String countColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.SUM, + CreateMaterializedViewStmt.mvColumnBuilder(slotToCaseWhen(child).toSql()))); + + Column mvColumn = context.checkContext.getColumn(countColumn); + if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) { + Slot countSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index).stream() + .filter(s -> countColumn.equalsIgnoreCase(normalizeName(s.getName()))).findFirst() + .orElseThrow(() -> new AnalysisException("cannot find count slot when select mv")); - context.exprRewriteMap.slotMap.put(slotOpt.get(), countSlot); - context.exprRewriteMap.projectExprMap.put(slotOpt.get(), countSlot); - Sum sum = new Sum(countSlot); - context.exprRewriteMap.aggFuncMap.put(count, sum); - return sum; + if (child instanceof Slot) { + context.exprRewriteMap.slotMap.put((Slot) child, countSlot); } + context.exprRewriteMap.projectExprMap.put(child, countSlot); + Sum sum = new Sum(countSlot); + context.exprRewriteMap.aggFuncMap.put(count, sum); + return sum; } } return count; } + /** + * bitmap_union(to_bitmap(col)) -> + * bitmap_union(mva_BITMAP_UNION__to_bitmap_with_check(col)) + */ + @Override + public Expression visitBitmapUnion(BitmapUnion bitmapUnion, RewriteContext context) { + Expression result = visitAggregateFunction(bitmapUnion, context); + if (result != bitmapUnion) { + return result; + } + if (bitmapUnion.child() instanceof ToBitmap) { + ToBitmap toBitmap = (ToBitmap) bitmapUnion.child(); + Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(toBitmap.child()); + if (slotOpt.isPresent()) { + String bitmapUnionColumn = normalizeName(CreateMaterializedViewStmt + .mvColumnBuilder(AggregateType.BITMAP_UNION, CreateMaterializedViewStmt + .mvColumnBuilder(new ToBitmapWithCheck(toBitmap.child()).toSql()))); + + Column mvColumn = context.checkContext.getColumn(bitmapUnionColumn); + // has bitmap_union column + if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) { + + Slot bitmapUnionSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index) + .stream().filter(s -> bitmapUnionColumn.equalsIgnoreCase(normalizeName(s.getName()))) + .findFirst().orElseThrow( + () -> new AnalysisException("cannot find bitmap union slot when select mv")); + + context.exprRewriteMap.slotMap.put(slotOpt.get(), bitmapUnionSlot); + context.exprRewriteMap.projectExprMap.put(toBitmap, bitmapUnionSlot); + BitmapUnion newBitmapUnion = new BitmapUnion(bitmapUnionSlot); + context.exprRewriteMap.aggFuncMap.put(bitmapUnion, newBitmapUnion); + return newBitmapUnion; + } + } + } else { + Expression child = bitmapUnion.child(); + String bitmapUnionColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder( + AggregateType.BITMAP_UNION, CreateMaterializedViewStmt.mvColumnBuilder(child.toSql()))); + + Column mvColumn = context.checkContext.getColumn(bitmapUnionColumn); + // has bitmap_union column + if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) { + + Slot bitmapUnionSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index) + .stream().filter(s -> bitmapUnionColumn.equalsIgnoreCase(normalizeName(s.getName()))) + .findFirst() + .orElseThrow(() -> new AnalysisException("cannot find bitmap union slot when select mv")); + if (child instanceof Slot) { + context.exprRewriteMap.slotMap.put((Slot) child, bitmapUnionSlot); + } + context.exprRewriteMap.projectExprMap.put(child, bitmapUnionSlot); + BitmapUnion newBitmapUnion = new BitmapUnion(bitmapUnionSlot); + context.exprRewriteMap.aggFuncMap.put(bitmapUnion, newBitmapUnion); + return newBitmapUnion; + } + } + + return bitmapUnion; + } + /** * bitmap_union_count(to_bitmap(col)) -> bitmap_union_count(mva_BITMAP_UNION__to_bitmap_with_check(col)) */ @@ -1300,32 +1363,26 @@ public Expression visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, Rewri return newBitmapUnionCount; } } - } else if (bitmapUnionCount.child() instanceof BitmapHash) { - BitmapHash bitmapHash = (BitmapHash) bitmapUnionCount.child(); - Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(bitmapHash.child()); - if (slotOpt.isPresent()) { - String bitmapUnionCountColumn = normalizeName( - CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.BITMAP_UNION, - CreateMaterializedViewStmt.mvColumnBuilder(bitmapHash.toSql()))); - - Column mvColumn = context.checkContext.getColumn(bitmapUnionCountColumn); - // has bitmap_union_count column - if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) { + } else { + Expression child = bitmapUnionCount.child(); + String bitmapUnionCountColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder( + AggregateType.BITMAP_UNION, CreateMaterializedViewStmt.mvColumnBuilder(child.toSql()))); - Slot bitmapUnionCountSlot = context.checkContext.scan - .getOutputByIndex(context.checkContext.index) - .stream() - .filter(s -> bitmapUnionCountColumn.equalsIgnoreCase(normalizeName(s.getName()))) - .findFirst() - .orElseThrow(() -> new AnalysisException( - "cannot find bitmap union count slot when select mv")); + Column mvColumn = context.checkContext.getColumn(bitmapUnionCountColumn); + // has bitmap_union_count column + if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) { - context.exprRewriteMap.slotMap.put(slotOpt.get(), bitmapUnionCountSlot); - context.exprRewriteMap.projectExprMap.put(bitmapHash, bitmapUnionCountSlot); - BitmapUnionCount newBitmapUnionCount = new BitmapUnionCount(bitmapUnionCountSlot); - context.exprRewriteMap.aggFuncMap.put(bitmapUnionCount, newBitmapUnionCount); - return newBitmapUnionCount; + Slot bitmapUnionCountSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index) + .stream().filter(s -> bitmapUnionCountColumn.equalsIgnoreCase(normalizeName(s.getName()))) + .findFirst().orElseThrow( + () -> new AnalysisException("cannot find bitmap union count slot when select mv")); + if (child instanceof Slot) { + context.exprRewriteMap.slotMap.put((Slot) child, bitmapUnionCountSlot); } + context.exprRewriteMap.projectExprMap.put(child, bitmapUnionCountSlot); + BitmapUnionCount newBitmapUnionCount = new BitmapUnionCount(bitmapUnionCountSlot); + context.exprRewriteMap.aggFuncMap.put(bitmapUnionCount, newBitmapUnionCount); + return newBitmapUnionCount; } } @@ -1419,8 +1476,7 @@ public Expression visitNdv(Ndv ndv, RewriteContext context) { } Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(ndv.child(0)); // ndv on a value column. - if (slotOpt.isPresent() && !context.checkContext.keyNameToColumn.containsKey( - normalizeName(slotOpt.get().toSql()))) { + if (slotOpt.isPresent()) { Expression expr = castIfNeed(ndv.child(), VarcharType.SYSTEM_DEFAULT); String hllUnionColumn = normalizeName( CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.HLL_UNION, @@ -1453,8 +1509,7 @@ public Expression visitSum(Sum sum, RewriteContext context) { return result; } Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(sum.child(0)); - if (!sum.isDistinct() && slotOpt.isPresent() - && !context.checkContext.keyNameToColumn.containsKey(normalizeName(slotOpt.get().toSql()))) { + if (!sum.isDistinct() && slotOpt.isPresent()) { Expression expr = castIfNeed(sum.child(), BigIntType.INSTANCE); String sumColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.SUM, CreateMaterializedViewStmt.mvColumnBuilder(expr.toSql()))); @@ -1490,10 +1545,8 @@ public Expression visitAggregateFunction(AggregateFunction aggregateFunction, Re Set slots = aggregateFunction.collect(SlotReference.class::isInstance); for (Slot slot : slots) { - if (!context.checkContext.keyNameToColumn.containsKey(normalizeName(slot.toSql()))) { - context.exprRewriteMap.slotMap.put(slot, aggStateSlot); - context.exprRewriteMap.projectExprMap.put(slot, aggStateSlot); - } + context.exprRewriteMap.slotMap.put(slot, aggStateSlot); + context.exprRewriteMap.projectExprMap.put(slot, aggStateSlot); } MergeCombinator mergeCombinator = new MergeCombinator(Arrays.asList(aggStateSlot), aggregateFunction); diff --git a/regression-test/data/mv_p0/ut/testBitmapUnionInQuery/testBitmapUnionInQuery.out b/regression-test/data/mv_p0/ut/testBitmapUnionInQuery/testBitmapUnionInQuery.out index 88913c5b65d879..75a47cb33192a5 100644 --- a/regression-test/data/mv_p0/ut/testBitmapUnionInQuery/testBitmapUnionInQuery.out +++ b/regression-test/data/mv_p0/ut/testBitmapUnionInQuery/testBitmapUnionInQuery.out @@ -7,3 +7,6 @@ -- !select_mv -- 1 2 +-- !select_mv -- +1 2 + diff --git a/regression-test/suites/mv_p0/ut/testBitmapUnionInQuery/testBitmapUnionInQuery.groovy b/regression-test/suites/mv_p0/ut/testBitmapUnionInQuery/testBitmapUnionInQuery.groovy index 798d350b92cd7b..d4502dd03b19b2 100644 --- a/regression-test/suites/mv_p0/ut/testBitmapUnionInQuery/testBitmapUnionInQuery.groovy +++ b/regression-test/suites/mv_p0/ut/testBitmapUnionInQuery/testBitmapUnionInQuery.groovy @@ -47,4 +47,10 @@ suite ("testBitmapUnionInQuery") { contains "(user_tags_mv)" } qt_select_mv "select user_id, bitmap_union_count(to_bitmap(tag_id)) a from user_tags group by user_id having a>1 order by a;" + + explain { + sql("select user_id, bitmap_count(bitmap_union(to_bitmap(tag_id))) a from user_tags group by user_id having a>1 order by a;") + contains "(user_tags_mv)" + } + qt_select_mv "select user_id, bitmap_count(bitmap_union(to_bitmap(tag_id))) a from user_tags group by user_id having a>1 order by a;" }