Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.ListQuery;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
Expand All @@ -38,7 +40,6 @@
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
Expand Down Expand Up @@ -73,7 +74,8 @@ public List<Rule> buildRules() {
RuleType.FILTER_SUBQUERY_TO_APPLY.build(
logicalFilter().thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;

boolean shouldOutputMarkJoinSlot = filter.getConjuncts().stream()
.anyMatch(expr -> shouldOutputMarkJoinSlot(expr, SearchState.SearchNot));
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = filter.getConjuncts().stream()
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryToApply::canConvertToSupply))
.collect(ImmutableList.toImmutableList());
Expand All @@ -98,7 +100,7 @@ public List<Rule> buildRules() {
// first step: Replace the subquery of predicate in LogicalFilter
// second step: Replace subquery with LogicalApply
ReplaceSubquery replaceSubquery = new ReplaceSubquery(
ctx.statementContext, false);
ctx.statementContext, shouldOutputMarkJoinSlot);
SubqueryContext context = new SubqueryContext(subqueryExprs);
Expression conjunct = replaceSubquery.replace(oldConjuncts.get(i), context);

Expand Down Expand Up @@ -343,7 +345,7 @@ private boolean nonMarkJoinExistsWithAgg(SubqueryExpr exists,
&& hasTopLevelAggWithoutGroupBy(exists.getQueryPlan());
}

private boolean hasTopLevelAggWithoutGroupBy(Plan plan) {
private static boolean hasTopLevelAggWithoutGroupBy(Plan plan) {
if (plan instanceof LogicalAggregate) {
return ((LogicalAggregate) plan).getGroupByExpressions().isEmpty();
} else if (plan instanceof LogicalProject || plan instanceof LogicalSort) {
Expand Down Expand Up @@ -427,7 +429,7 @@ public Expression visitExistsSubquery(Exists exists, SubqueryContext context) {
// it will always consider the returned result to be true
boolean needCreateMarkJoinSlot = isMarkJoin || shouldOutputMarkJoinSlot;
MarkJoinSlotReference markJoinSlotReference = null;
if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && needCreateMarkJoinSlot) {
if (hasTopLevelAggWithoutGroupBy(exists.getQueryPlan()) && needCreateMarkJoinSlot) {
markJoinSlotReference =
new MarkJoinSlotReference(statementContext.generateColumnName(), true);
} else if (needCreateMarkJoinSlot) {
Expand Down Expand Up @@ -505,4 +507,27 @@ private void setSubqueryToMarkJoinSlot(SubqueryExpr subquery,

}

private enum SearchState {
SearchNot,
SearchAnd,
SearchExistsOrInSubquery
}

private boolean shouldOutputMarkJoinSlot(Expression expr, SearchState searchState) {
if (searchState == SearchState.SearchNot && expr instanceof Not) {
if (shouldOutputMarkJoinSlot(((Not) expr).child(), SearchState.SearchAnd)) {
return true;
}
} else if (searchState == SearchState.SearchAnd && expr instanceof And) {
for (Expression child : expr.children()) {
if (shouldOutputMarkJoinSlot(child, SearchState.SearchExistsOrInSubquery)) {
return true;
}
}
} else if (searchState == SearchState.SearchExistsOrInSubquery
&& (expr instanceof InSubquery || expr instanceof Exists)) {
return true;
}
return false;
}
}
4 changes: 4 additions & 0 deletions regression-test/data/nereids_p0/subquery/test_subquery.out
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ true 15 1992 3021 11011920 0.000 true 9999-12-12 2015-04-02T00:00 3.141592653 2
-- !sql_mark_join --
1

-- !select_sub --
1 9
2 \N

93 changes: 93 additions & 0 deletions regression-test/suites/nereids_p0/subquery/test_subquery.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,97 @@ suite("test_subquery") {
qt_sql_mark_join """with A as (select count(*) n1 from test_one_row_relation where exists (select 1 from test_one_row_relation t where t.user_id = test_one_row_relation.user_id) or 1 = 1) select * from A;"""

sql """drop table if exists test_one_row_relation;"""

sql """drop table if exists subquery_test_t1;"""
sql """drop table if exists subquery_test_t2;"""
sql """create table subquery_test_t1 (
id int
)
UNIQUE KEY (`id`)
DISTRIBUTED BY HASH(`id`) BUCKETS 1
PROPERTIES ("replication_allocation" = "tag.location.default: 1");"""
sql """create table subquery_test_t2 (
id int
)
UNIQUE KEY (`id`)
DISTRIBUTED BY HASH(`id`) BUCKETS 1
PROPERTIES ("replication_allocation" = "tag.location.default: 1");"""

explain {
sql("""analyzed plan select subquery_test_t1.id from subquery_test_t1
where
not (
exists(select 1 from subquery_test_t2 where subquery_test_t1.id = subquery_test_t2.id and subquery_test_t2.id = 5)
and
exists(select 1 from subquery_test_t2 where subquery_test_t1.id = subquery_test_t2.id and subquery_test_t2.id = 6)
); """)
contains("isMarkJoin=true")
}
explain {
sql("""analyzed plan select subquery_test_t1.id from subquery_test_t1
where
not (
subquery_test_t1.id > 10
and
exists(select 1 from subquery_test_t2 where subquery_test_t1.id = subquery_test_t2.id and subquery_test_t2.id = 6)
);""")
contains("isMarkJoin=true")
}
explain {
sql("""analyzed plan select subquery_test_t1.id from subquery_test_t1
where
not (
subquery_test_t1.id > 10
and
subquery_test_t1.id in (select 1 from subquery_test_t2 where subquery_test_t1.id = subquery_test_t2.id and subquery_test_t2.id = 6)
); """)
contains("isMarkJoin=true")
}
explain {
sql("""analyzed plan select subquery_test_t1.id from subquery_test_t1
where
not (
subquery_test_t1.id > 10
and
subquery_test_t1.id in (select 1 from subquery_test_t2 where subquery_test_t1.id = subquery_test_t2.id and subquery_test_t2.id = 6)
); """)
contains("isMarkJoin=true")
}
explain {
sql("""analyzed plan select subquery_test_t1.id from subquery_test_t1
where
not (
subquery_test_t1.id > 10
and
( subquery_test_t1.id < 100 or subquery_test_t1.id in (select 1 from subquery_test_t2 where subquery_test_t1.id = subquery_test_t2.id and subquery_test_t2.id = 6) )
); """)
contains("isMarkJoin=true")
}
explain {
sql("""analyzed plan select subquery_test_t1.id from subquery_test_t1
where
not (
subquery_test_t1.id > 10
and
( subquery_test_t1.id < 100 or case when subquery_test_t1.id in (select 1 from subquery_test_t2 where subquery_test_t1.id = subquery_test_t2.id and subquery_test_t2.id = 6) then 1 else 0 end )
);""")
contains("isMarkJoin=true")
}

sql """drop table if exists table_23_undef_undef"""
sql """create table table_23_undef_undef (`pk` int,`col_int_undef_signed` int ,`col_varchar_10__undef_signed` varchar(10) ,`col_varchar_1024__undef_signed` varchar(1024) ) engine=olap distributed by hash(pk) buckets 10 properties( 'replication_num' = '1');"""
sql """drop table if exists table_20_undef_undef"""
sql """create table table_20_undef_undef (`pk` int,`col_int_undef_signed` int ,`col_varchar_10__undef_signed` varchar(10) ,`col_varchar_1024__undef_signed` varchar(1024) ) engine=olap distributed by hash(pk) buckets 10 properties( 'replication_num' = '1');"""
sql """drop table if exists table_9_undef_undef"""
sql """create table table_9_undef_undef (`pk` int,`col_int_undef_signed` int ,`col_varchar_10__undef_signed` varchar(10) ,`col_varchar_1024__undef_signed` varchar(1024) ) engine=olap distributed by hash(pk) buckets 10 properties( 'replication_num' = '1');"""

sql """insert into table_23_undef_undef values (0,0,'t','p'),(1,6,'q',"really"),(2,3,'p',"of"),(3,null,"he",'k'),(4,8,"this","don't"),(5,6,"see","this"),(6,5,'s','q'),(7,null,'o','j'),(8,9,'l',"could"),(9,null,"one",'l'),(10,7,"can't",'f'),(11,2,"going","not"),(12,null,'g','r'),(13,3,"ok",'s'),(14,6,"she",'k'),(15,null,"she",'p'),(16,8,"what","him"),(17,null,"from","to"),(18,5,"so","up"),(19,null,"my","is"),(20,null,'h',"see"),(21,null,"as","to"),(22,0,"know","the");"""
sql """insert into table_20_undef_undef values (0,null,'r','x'),(1,null,'m',"say"),(2,2,"mean",'h'),(3,null,'n','b'),(4,8,"do","do"),(5,9,'h',"were"),(6,null,"was","one"),(7,2,'o',"she"),(8,0,"who","me"),(9,null,'n',"that"),(10,null,"will",'l'),(11,4,'m',"if"),(12,5,"the","got"),(13,null,"why",'f'),(14,0,"of","for"),(15,null,"or","ok"),(16,null,'c','u'),(17,3,'f','c'),(18,null,"see",'f'),(19,2,'f','z');"""
sql """insert into table_9_undef_undef values (0,3,"his",'g'),(1,8,'p','n'),(2,null,"get","got"),(3,3,'r','r'),(4,null,"or","get"),(5,0,'j',"yeah"),(6,null,'w','x'),(7,8,'q',"for"),(8,3,'p',"that");"""

qt_select_sub"""SELECT DISTINCT alias1.`pk` AS field1, alias2.`col_int_undef_signed` AS field2 FROM table_23_undef_undef AS alias1, table_20_undef_undef AS alias2 WHERE ( EXISTS ( SELECT DISTINCT SQ1_alias1.`col_varchar_10__undef_signed` AS SQ1_field1 FROM table_9_undef_undef AS SQ1_alias1 WHERE SQ1_alias1.`col_varchar_10__undef_signed` = alias1.`col_varchar_10__undef_signed` ) ) OR alias1.`col_varchar_1024__undef_signed` = "TmxRwcNZHC" AND ( alias1.`col_varchar_10__undef_signed` <> "rnZeukOcuM" AND alias2.`col_varchar_10__undef_signed` != "dbPAEpzstk" ) ORDER BY alias1.`pk`, field1, field2 LIMIT 2 OFFSET 7; """
sql """drop table if exists table_23_undef_undef"""
sql """drop table if exists table_20_undef_undef"""
sql """drop table if exists table_9_undef_undef"""

}