diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index 93a7af2d8e095a..5c50ca9effdc36 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -105,7 +105,8 @@ public class StatementContext implements Closeable { * indicate where the table come from. * QUERY: in query sql directly * INSERT_TARGET: the insert target table - * MTMV: mtmv itself and its related tables witch do not belong to this sql, but maybe used in rewrite by mtmv. + * MTMV: mtmv itself and its related tables witch do not belong to this sql, but + * maybe used in rewrite by mtmv. */ public enum TableFrom { QUERY, @@ -122,7 +123,8 @@ public enum TableFrom { private final Map> contextCacheMap = Maps.newLinkedHashMap(); private OriginStatement originStatement; - // NOTICE: we set the plan parsed by DorisParser to parsedStatement and if the plan is command, create a + // NOTICE: we set the plan parsed by DorisParser to parsedStatement and if the + // plan is command, create a // LogicalPlanAdapter with the logical plan in the command. private StatementBase parsedStatement; private ColumnAliasGenerator columnAliasGenerator; @@ -134,10 +136,14 @@ public enum TableFrom { private boolean hasNondeterministic = false; - // hasUnknownColStats true if any column stats in the tables used by this sql is unknown - // the algorithm to derive plan when column stats are unknown is implemented in cascading framework, not in dphyper. - // And hence, when column stats are unknown, even if the tables used by a sql is more than - // MAX_TABLE_COUNT_USE_CASCADES_JOIN_REORDER, join reorder should choose cascading framework. 
+ // hasUnknownColStats true if any column stats in the tables used by this sql is + // unknown + // the algorithm to derive plan when column stats are unknown is implemented in + // cascading framework, not in dphyper. + // And hence, when column stats are unknown, even if the tables used by a sql is + // more than + // MAX_TABLE_COUNT_USE_CASCADES_JOIN_REORDER, join reorder should choose + // cascading framework. // Thus hasUnknownColStats has higher priority than isDpHyp private boolean hasUnknownColStats = false; @@ -160,11 +166,13 @@ public enum TableFrom { private final Set viewDdlSqlSet = Sets.newHashSet(); private final SqlCacheContext sqlCacheContext; - // generate for next id for prepared statement's placeholders, which is connection level + // generate for next id for prepared statement's placeholders, which is + // connection level private final IdGenerator placeHolderIdGenerator = PlaceholderId.createGenerator(); // relation id to placeholders for prepared statement, ordered by placeholder id private final Map idToPlaceholderRealExpr = new TreeMap<>(); - // map placeholder id to comparison slot, which will used to replace conjuncts directly + // map placeholder id to comparison slot, which will used to replace conjuncts + // directly private final Map idToComparisonSlot = new TreeMap<>(); // collect all hash join conditions to compute node connectivity in join graph @@ -173,7 +181,8 @@ public enum TableFrom { private final List hints = new ArrayList<>(); private boolean hintForcePreAggOn = false; - // the columns in Plan.getExpressions(), such as columns in join condition or filter condition, group by expression + // the columns in Plan.getExpressions(), such as columns in join condition or + // filter condition, group by expression private final Set keySlots = Sets.newHashSet(); private BitSet disableRules; @@ -217,19 +226,24 @@ public enum TableFrom { private final List insertTargetSchema = new ArrayList<>(); // for create view support in nereids - 
// key is the start and end position of the sql substring that needs to be replaced, + // key is the start and end position of the sql substring that needs to be + // replaced, // and value is the new string used for replacement. - private final TreeMap, String> indexInSqlToString - = new TreeMap<>(new Pair.PairComparator<>()); - // Record table id mapping, the key is the hash code of union catalogId, databaseId, tableId + private final TreeMap, String> indexInSqlToString = new TreeMap<>( + new Pair.PairComparator<>()); + // Record table id mapping, the key is the hash code of union catalogId, + // databaseId, tableId // the value is the auto-increment id in the cascades context private final Map, TableId> tableIdMapping = new LinkedHashMap<>(); - // Record the materialization statistics by id which is used for cost estimation. - // Maybe return null, which means the id according statistics should calc normally rather than getting + // Record the materialization statistics by id which is used for cost + // estimation. 
+ // Maybe return null, which means the id according statistics should calc + // normally rather than getting // form this map private final Map relationIdToStatisticsMap = new LinkedHashMap<>(); - // Indicates the query is short-circuited in both plan and execution phase, typically + // Indicates the query is short-circuited in both plan and execution phase, + // typically // for high speed/concurrency point queries private boolean isShortCircuitQuery; @@ -251,8 +265,8 @@ public enum TableFrom { private long materializedViewRewriteDuration = 0L; // Record used table and it's used partitions - private final Multimap, Pair>> tableUsedPartitionNameMap = - HashMultimap.create(); + private final Multimap, Pair>> tableUsedPartitionNameMap = HashMultimap + .create(); private final Map relationIdToCommonTableIdMap = new HashMap<>(); // Record mtmv and valid partitions map because this is time-consuming behavior @@ -270,8 +284,11 @@ public enum TableFrom { // this record the rewritten plan by mv in RBO phase private final List rewrittenPlansByMv = new ArrayList<>(); private boolean forceRecordTmpPlan = false; - // this record the rule in PreMaterializedViewRewriter.NEED_PRE_REWRITE_RULE_TYPES if is applied successfully - // or not, if success and in PreRewriteStrategy.FOR_IN_ROB or PreRewriteStrategy.TRY_IN_ROB, mv + // this record the rule in + // PreMaterializedViewRewriter.NEED_PRE_REWRITE_RULE_TYPES if is applied + // successfully + // or not, if success and in PreRewriteStrategy.FOR_IN_ROB or + // PreRewriteStrategy.TRY_IN_ROB, mv // would be written in RBO phase private final BitSet needPreMvRewriteRuleMasks = new BitSet(RuleType.SENTINEL.ordinal()); // if needed to rewrite in RBO phase, this would be set true @@ -286,7 +303,8 @@ public enum TableFrom { private Optional>> mvRefreshPredicates = Optional.empty(); - // For Iceberg rewrite operations: store file scan tasks to be used by IcebergScanNode + // For Iceberg rewrite operations: store file scan tasks to be used 
by + // IcebergScanNode // TODO: better solution? private List icebergRewriteFileScanTasks = null; private boolean hasNestedColumns; @@ -312,7 +330,8 @@ public StatementContext(ConnectContext connectContext, OriginStatement originSta exprIdGenerator = ExprId.createGenerator(initialId); if (connectContext != null && connectContext.getSessionVariable() != null) { if (CacheAnalyzer.canUseSqlCache(connectContext.getSessionVariable())) { - // cannot set the queryId here because the queryId for the current query is set in the subsequent steps. + // cannot set the queryId here because the queryId for the current query is set + // in the subsequent steps. this.sqlCacheContext = new SqlCacheContext( connectContext.getCurrentUserIdentity()); if (originStatement != null) { @@ -342,7 +361,7 @@ public boolean isHintForcePreAggOn() { * cache view info to avoid view's def and sql mode changed before lock it. * * @param qualifiedViewName full qualified name of the view - * @param view view need to cache info + * @param view view need to cache info * * @return view info, first is view's def sql, second is view's sql mode */ @@ -560,8 +579,8 @@ public synchronized void invalidCache(String cacheKey) { public ColumnAliasGenerator getColumnAliasGenerator() { return columnAliasGenerator == null - ? columnAliasGenerator = new ColumnAliasGenerator() - : columnAliasGenerator; + ? columnAliasGenerator = new ColumnAliasGenerator() + : columnAliasGenerator; } public String generateColumnName() { @@ -608,6 +627,91 @@ public Map getRewrittenCteConsumer() { return rewrittenCteConsumer; } + /** + * Snapshot current CTE-related environment for temporary rewrite/optimization. 
+     */
+    public CteEnvironmentSnapshot cacheCteEnvironment() {
+        return new CteEnvironmentSnapshot(
+                copyMapOfSets(cteIdToConsumers),
+                copyMapOfSets(cteIdToOutputIds),
+                new HashMap<>(cteIdToProducerStats),
+                copyMapOfSets(consumerIdToFilters),
+                copyMapOfLists(cteIdToConsumerGroup),
+                new HashMap<>(rewrittenCteProducer),
+                new HashMap<>(rewrittenCteConsumer));
+    }
+
+    /**
+     * Restore CTE-related environment from snapshot.
+     * Nested sets and lists are copied again on every restore, so later changes to the live
+     * collections cannot leak back into the snapshot and it can be restored more than once.
+     */
+    public void restoreCteEnvironment(CteEnvironmentSnapshot snapshot) {
+        cteIdToConsumers.clear();
+        cteIdToConsumers.putAll(copyMapOfSets(snapshot.cteIdToConsumers));
+        cteIdToOutputIds.clear();
+        cteIdToOutputIds.putAll(copyMapOfSets(snapshot.cteIdToOutputIds));
+        cteIdToProducerStats.clear();
+        cteIdToProducerStats.putAll(snapshot.cteIdToProducerStats);
+        consumerIdToFilters.clear();
+        consumerIdToFilters.putAll(copyMapOfSets(snapshot.consumerIdToFilters));
+        cteIdToConsumerGroup.clear();
+        cteIdToConsumerGroup.putAll(copyMapOfLists(snapshot.cteIdToConsumerGroup));
+        rewrittenCteProducer.clear();
+        rewrittenCteProducer.putAll(snapshot.rewrittenCteProducer);
+        rewrittenCteConsumer.clear();
+        rewrittenCteConsumer.putAll(snapshot.rewrittenCteConsumer);
+    }
+
+    /** Copy a map of sets: every value set is duplicated, keys and elements are shared. */
+    private static <K, V> Map<K, Set<V>> copyMapOfSets(Map<K, Set<V>> source) {
+        Map<K, Set<V>> copied = new HashMap<>();
+        for (Map.Entry<K, Set<V>> entry : source.entrySet()) {
+            copied.put(entry.getKey(), new HashSet<>(entry.getValue()));
+        }
+        return copied;
+    }
+
+    /** Copy a map of lists: every value list is duplicated, keys and elements are shared. */
+    private static <K, V> Map<K, List<V>> copyMapOfLists(Map<K, List<V>> source) {
+        Map<K, List<V>> copied = new HashMap<>();
+        for (Map.Entry<K, List<V>> entry : source.entrySet()) {
+            copied.put(entry.getKey(), new ArrayList<>(entry.getValue()));
+        }
+        return copied;
+    }
+
+    /** Holder for cached CTE-related environment.
*/ + public static class CteEnvironmentSnapshot { + private final Map> cteIdToConsumers; + private final Map> cteIdToOutputIds; + private final Map cteIdToProducerStats; + private final Map> consumerIdToFilters; + private final Map, Group>>> cteIdToConsumerGroup; + private final Map rewrittenCteProducer; + private final Map rewrittenCteConsumer; + + /** + * cte related structures in StatementContext + */ + public CteEnvironmentSnapshot( + Map> cteIdToConsumers, + Map> cteIdToOutputIds, + Map cteIdToProducerStats, + Map> consumerIdToFilters, + Map, Group>>> cteIdToConsumerGroup, + Map rewrittenCteProducer, + Map rewrittenCteConsumer) { + this.cteIdToConsumers = cteIdToConsumers; + this.cteIdToOutputIds = cteIdToOutputIds; + this.cteIdToProducerStats = cteIdToProducerStats; + this.consumerIdToFilters = consumerIdToFilters; + this.cteIdToConsumerGroup = cteIdToConsumerGroup; + this.rewrittenCteProducer = rewrittenCteProducer; + this.rewrittenCteConsumer = rewrittenCteConsumer; + } + } + public void addViewDdlSql(String ddlSql) { this.viewDdlSqlSet.add(ddlSql); } @@ -656,6 +760,7 @@ public void addStatistics(Id id, Statistics statistics) { /** * get used mv hint by hint name + * * @param useMvName hint name, can either be USE_MV or NO_USE_MV * @return optional of useMvHint */ @@ -813,7 +918,7 @@ public Optional getSnapshot(TableIf tableIf) { * Obtain snapshot information of mvcc * * @param mvccTableInfo mvccTableInfo - * @param snapshot snapshot + * @param snapshot snapshot */ public void setSnapshot(MvccTableInfo mvccTableInfo, MvccSnapshot snapshot) { snapshots.put(mvccTableInfo, snapshot); @@ -1026,7 +1131,8 @@ public void setMvRefreshPredicates( /** * Set file scan tasks for Iceberg rewrite operations. - * This allows IcebergScanNode to use specific file scan tasks instead of scanning the full table. + * This allows IcebergScanNode to use specific file scan tasks instead of + * scanning the full table. 
*/ public void setIcebergRewriteFileScanTasks(List tasks) { this.icebergRewriteFileScanTasks = tasks; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java index fd5b7a3f04ee80..453b805a3b3b08 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java @@ -19,6 +19,7 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.hint.Hint; import org.apache.doris.nereids.hint.UseCboRuleHint; @@ -28,6 +29,7 @@ import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.qe.ConnectContext; @@ -58,7 +60,8 @@ public CostBasedRewriteJob(List rewriteJobs) { @Override public void execute(JobContext jobContext) { - // checkHint.first means whether it use hint and checkHint.second means what kind of hint it used + // checkHint.first means whether it use hint and checkHint.second means what + // kind of hint it used Pair checkHint = checkRuleHint(); // this means it no_use_cbo_rule(xxx) hint if (checkHint.first && checkHint.second == null) { @@ -69,14 +72,18 @@ public void execute(JobContext jobContext) { CascadesContext applyCboRuleCtx = CascadesContext.newCurrentTreeContext(currentCtx); // execute cbo rule on one candidate Rewriter.getCteChildrenRewriter(applyCboRuleCtx, rewriteJobs).execute(); + Plan applyCboPlan = applyCboRuleCtx.getRewritePlan(); if 
(skipCboRuleCtx.getRewritePlan().deepEquals(applyCboRuleCtx.getRewritePlan())) { // this means rewrite do not do anything return; } + StatementContext.CteEnvironmentSnapshot cteEnvSnapshot = currentCtx.getStatementContext().cacheCteEnvironment(); // compare two candidates Optional> skipCboRuleCost = getCost(currentCtx, skipCboRuleCtx, jobContext); + currentCtx.getStatementContext().restoreCteEnvironment(cteEnvSnapshot); Optional> appliedCboRuleCost = getCost(currentCtx, applyCboRuleCtx, jobContext); + currentCtx.getStatementContext().restoreCteEnvironment(cteEnvSnapshot); // If one of them optimize failed, just return if (!skipCboRuleCost.isPresent() || !appliedCboRuleCost.isPresent()) { LOG.warn("Cbo rewrite execute failed on sql: {}, jobs are {}, plan is {}.", @@ -92,19 +99,20 @@ public void execute(JobContext jobContext) { } return; } - // If the candidate applied cbo rule is better, replace the original plan with it. + // If the candidate applied cbo rule is better, replace the original plan with + // it. 
if (appliedCboRuleCost.get().first.getValue() < skipCboRuleCost.get().first.getValue()) { - currentCtx.addPlanProcesses(applyCboRuleCtx.getPlanProcesses()); - currentCtx.setRewritePlan(applyCboRuleCtx.getRewritePlan()); + currentCtx.setRewritePlan(applyCboPlan); } } /** * check if we have use rule hint or no use rule hint - * return an optional object which checkHint.first means whether it use hint - * and checkHint.second means what kind of hint it used - * example, when we use *+ no_use_cbo_rule(xxx) * the optional would be (true, false) - * which means it use hint and the hint forbid this kind of rule + * return an optional object which checkHint.first means whether it use hint + * and checkHint.second means what kind of hint it used + * example, when we use *+ no_use_cbo_rule(xxx) * the optional would be (true, + * false) + * which means it use hint and the hint forbid this kind of rule */ private Pair checkRuleHint() { Pair checkResult = Pair.of(false, null); @@ -134,7 +142,8 @@ private Pair checkRuleHint() { } /** - * for these rules we need use_cbo_rule hint to enable it, otherwise it would be close by default + * for these rules we need use_cbo_rule hint to enable it, otherwise it would be + * close by default */ private static boolean checkBlackList(RuleType ruleType) { List ruleWhiteList = new ArrayList<>(Arrays.asList( diff --git a/regression-test/suites/nereids_p0/cte/costbasedrewrite_producer/costbasedrewrite_producer.groovy b/regression-test/suites/nereids_p0/cte/costbasedrewrite_producer/costbasedrewrite_producer.groovy new file mode 100644 index 00000000000000..128f35d7232756 --- /dev/null +++ b/regression-test/suites/nereids_p0/cte/costbasedrewrite_producer/costbasedrewrite_producer.groovy @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +suite("costbasedrewrite_producer") { + sql """ + drop table if exists t1; + + create table t1(a1 int,b1 int) + properties("replication_num" = "1"); + + insert into t1 values(1,2); + + drop table if exists t2; + + create table t2(a2 int,b2 int) + properties("replication_num" = "1"); + + insert into t2 values(1,3); + """ + + sql""" + with cte1 as ( + select t1.a1, t1.b1 + from t1 + where t1.a1 > 0 and not exists (select distinct t2.b2 from t2 where t1.a1 = t2.a2 or t1.b1 = t2.a2) + ), + cte2 as ( + select * from cte1 union select * from cte1) + select * from cte2 join t1 on cte2.a1 = t1.a1; + + """ +}