Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
Expand Down Expand Up @@ -111,42 +112,66 @@ public List<Rule> buildRules() {
logicalAggregate(
logicalFilter(
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
).when(filter -> !filter.getConjuncts().isEmpty()))
.when(agg -> enablePushDownCountOnIndex())
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> f instanceof Count && !f.isDistinct() && (((Count) f).isCountStar()
|| f.child(0) instanceof Slot));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
LogicalFilter<LogicalOlapScan> filter = agg.child();
LogicalOlapScan olapScan = filter.child();
return pushdownCountOnIndex(agg, null, filter, olapScan, ctx.cascadesContext);
})
)
)
.when(agg -> enablePushDownCountOnIndex())
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
if (funcs.isEmpty() || !funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot))) {
return false;
}
Set<Expression> conjuncts = agg.child().getConjuncts();
if (conjuncts.isEmpty()) {
return false;
}

Set<Slot> aggSlots = funcs.stream()
.flatMap(f -> f.getInputSlots().stream())
.collect(Collectors.toSet());
return conjuncts.stream().allMatch(expr -> checkSlotInOrExpression(expr, aggSlots));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
LogicalFilter<LogicalOlapScan> filter = agg.child();
LogicalOlapScan olapScan = filter.child();
return pushdownCountOnIndex(agg, null, filter, olapScan, ctx.cascadesContext);
})
),
RuleType.COUNT_ON_INDEX.build(
logicalAggregate(
logicalProject(
logicalFilter(
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
).when(filter -> !filter.getConjuncts().isEmpty())))
.when(agg -> enablePushDownCountOnIndex())
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
LogicalFilter<LogicalOlapScan> filter = project.child();
LogicalOlapScan olapScan = filter.child();
return pushdownCountOnIndex(agg, project, filter, olapScan, ctx.cascadesContext);
})
)
)
)
.when(agg -> enablePushDownCountOnIndex())
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
if (funcs.isEmpty() || !funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot))) {
return false;
}
Set<Expression> conjuncts = agg.child().child().getConjuncts();
if (conjuncts.isEmpty()) {
return false;
}

Set<Slot> aggSlots = funcs.stream()
.flatMap(f -> f.getInputSlots().stream())
.collect(Collectors.toSet());
return conjuncts.stream().allMatch(expr -> checkSlotInOrExpression(expr, aggSlots));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
LogicalFilter<LogicalOlapScan> filter = project.child();
LogicalOlapScan olapScan = filter.child();
return pushdownCountOnIndex(agg, project, filter, olapScan, ctx.cascadesContext);
})
),
RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE_WITHOUT_PROJECT.build(
logicalAggregate(
Expand Down Expand Up @@ -451,6 +476,22 @@ private boolean enablePushDownCountOnIndex() {
return connectContext != null && connectContext.getSessionVariable().isEnablePushDownCountOnIndex();
}

private boolean checkSlotInOrExpression(Expression expr, Set<Slot> aggSlots) {
if (expr instanceof Or) {
Set<Slot> slots = expr.getInputSlots();
if (!slots.stream().allMatch(aggSlots::contains)) {
return false;
}
} else {
for (Expression child : expr.children()) {
if (!checkSlotInOrExpression(child, aggSlots)) {
return false;
}
}
}
return true;
}

private boolean isDupOrMowKeyTable(LogicalOlapScan logicalScan) {
if (logicalScan != null) {
KeysType keysType = logicalScan.getTable().getKeysType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,12 @@
-- !sql --
3

-- !sql --
1

-- !sql --
1

-- !sql --
1

Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,35 @@ suite("test_count_on_index_2", "p0"){
qt_sql """ select count() from ${indexTbName3} where (a >= 10 and a < 20) and (b >= 5 and b < 14) and (c >= 16 and c < 25); """
qt_sql """ select count() from ${indexTbName3} where (a >= 10 and a < 20) and (b >= 5 and b < 16) and (c >= 13 and c < 25); """

sql """ DROP TABLE IF EXISTS tt """
sql """
CREATE TABLE `tt` (
`a` int NULL,
`b` int NULL,
`c` int NULL,
INDEX col_c (`b`) USING INVERTED,
INDEX col_b (`c`) USING INVERTED
) ENGINE=OLAP
DUPLICATE KEY(`a`)
COMMENT 'OLAP'
DISTRIBUTED BY RANDOM BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
"""

sql """ insert into tt values (20, 23, 30); """
sql """ insert into tt values (20, null, 30); """
qt_sql """ select count(b) from tt where b = 23 or c = 30; """
qt_sql """ select count(b) from tt where b = 23 and (c = 20 or c = 30); """
explain {
sql("select count(b) from tt where b = 23 and (c = 20 or c = 30);")
contains "COUNT_ON_INDEX"
}
explain {
sql("select count(b) from tt where b = 23 or b = 30;")
contains "COUNT_ON_INDEX"
}
} finally {
//try_sql("DROP TABLE IF EXISTS ${testTable}")
}
Expand Down