diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala new file mode 100644 index 0000000000000..1b1b526e78140 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.normalizer + +import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE} +import org.apache.spark.sql.catalyst.rules.Rule + +object NormalizeCTEIds extends Rule[LogicalPlan]{ + override def apply(plan: LogicalPlan): LogicalPlan = { + val curId = new java.util.concurrent.atomic.AtomicLong() + plan transformDown { + + case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) => + ctas.copy(plan = apply(plan)) + + case withCTE @ WithCTE(plan, cteDefs) => + val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap + val normalizedPlan = canonicalizeCTE(plan, defIdToNewId) + val newCteDefs = cteDefs.map { cteDef => + val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId) + cteDef.copy(child = normalizedCteDef, id = defIdToNewId(cteDef.id)) + } + withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs) + } + } + + def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: Map[Long, Long]): LogicalPlan = { + plan.transformDownWithSubqueries { + // For nested WithCTE, if defIndex didn't contain the cteId, + // means it's not current WithCTE's ref. + case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) => + ref.copy(cteId = defIdToNewId(ref.cteId)) + case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) => + unionLoop.copy(id = defIdToNewId(unionLoop.id)) + case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) => + unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 72274ee9bf174..fab64d771093f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1742,7 +1742,8 @@ case class CacheTableAsSelect( isLazy: Boolean, options: Map[String, String], isAnalyzed: Boolean = false, - referredTempFunctions: Seq[String] = Seq.empty) extends AnalysisOnlyCommand { + referredTempFunctions: Seq[String] = Seq.empty) + extends AnalysisOnlyCommand with CTEInChildren { override protected def withNewChildrenInternal( newChildren: IndexedSeq[LogicalPlan]): CacheTableAsSelect = { assert(!isAnalyzed) @@ -1757,6 +1758,10 @@ case class CacheTableAsSelect( // Collect the referred temporary functions from AnalysisContext referredTempFunctions = ac.referredTempFunctionNames.toSeq) } + + override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan = { + copy(plan = WithCTE(plan, cteDefs)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 00c9a26cb5bf3..040733294423c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTr import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Expression, ExtractSemiStructuredFields} +import org.apache.spark.sql.catalyst.normalizer.NormalizeCTEIds import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -403,6 +404,7 @@ abstract class BaseSessionStateBuilder( } protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = { + NormalizeCTEIds +: extensions.buildPlanNormalizationRules(session) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 12d26c4e195f1..880d8d72c73e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -2598,6 +2598,44 @@ class CachedTableSuite extends QueryTest with SQLTestUtils } } + test("SPARK-46741: Cache Table with CTE should work") { + withTempView("t1", "t2") { + sql( + """ + |CREATE TEMPORARY VIEW t1 + |AS + |SELECT * FROM VALUES (0, 0), (1, 1), (2, 2) AS t(c1, c2) + |""".stripMargin) + sql( + """ + |CREATE TEMPORARY VIEW t2 AS + |WITH v as ( + | SELECT c1 + c1 c3 FROM t1 + |) + |SELECT SUM(c3) s FROM v + |""".stripMargin) + sql( + """ + |CACHE TABLE cache_nested_cte_table + |WITH + |v AS ( + | SELECT c1 * c2 c3 from t1 + |) + |SELECT SUM(c3) FROM v + |EXCEPT + |SELECT s FROM t2 + |""".stripMargin) + + val df = sql("SELECT * FROM cache_nested_cte_table") + + val inMemoryTableScan = collect(df.queryExecution.executedPlan) { + case i: InMemoryTableScanExec => i + } + assert(inMemoryTableScan.size == 1) + checkAnswer(df, Row(5) :: Nil) + } + } + private def cacheManager = spark.sharedState.cacheManager private def pinTable(