diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index 1b1b526e78140..6c0bca0e1104f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -17,29 +17,42 @@
 
 package org.apache.spark.sql.catalyst.normalizer
 
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 
-object NormalizeCTEIds extends Rule[LogicalPlan]{
+object NormalizeCTEIds extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
-    plan transformDown {
+    val cteIdToNewId = mutable.Map.empty[Long, Long]
+    applyInternal(plan, curId, cteIdToNewId)
+  }
+
+  private def applyInternal(
+      plan: LogicalPlan,
+      curId: AtomicLong,
+      cteIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
+    plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
-        ctas.copy(plan = apply(plan))
+        ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId))
       case withCTE @ WithCTE(plan, cteDefs) =>
-        val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap
-        val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
         val newCteDefs = cteDefs.map { cteDef =>
-          val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId)
-          cteDef.copy(child = normalizedCteDef, id = defIdToNewId(cteDef.id))
+          cteIdToNewId.getOrElseUpdate(cteDef.id, curId.getAndIncrement())
+          val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
+          cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id))
         }
+        val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId)
         withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
     }
   }
 
-  def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: Map[Long, Long]): LogicalPlan = {
+  private def canonicalizeCTE(
+      plan: LogicalPlan,
+      defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
     plan.transformDownWithSubqueries {
       // For nested WithCTE, if defIndex didn't contain the cteId,
       // means it's not current WithCTE's ref.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 880d8d72c73e7..0d807aeae4d7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -2633,6 +2633,33 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
     }
     assert(inMemoryTableScan.size == 1)
     checkAnswer(df, Row(5) :: Nil)
+
+    sql(
+      """
+        |CACHE TABLE cache_subquery_cte_table
+        |WITH v AS (
+        |  SELECT c1 * c2 c3 from t1
+        |)
+        |SELECT *
+        |FROM v
+        |WHERE EXISTS (
+        |  WITH cte AS (SELECT 1 AS id)
+        |  SELECT 1
+        |  FROM cte
+        |  WHERE cte.id = v.c3
+        |)
+        |""".stripMargin)
+
+    val cteInSubquery = sql(
+      """
+        |SELECT * FROM cache_subquery_cte_table
+        |""".stripMargin)
+
+    val subqueryInMemoryTableScan = collect(cteInSubquery.queryExecution.executedPlan) {
+      case i: InMemoryTableScanExec => i
+    }
+    assert(subqueryInMemoryTableScan.size == 1)
+    checkAnswer(cteInSubquery, Row(1) :: Nil)
   }
 }
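
The core of the fix is that CTE id renumbering now uses one mutable map and one counter shared across the whole traversal (which descends into subquery expressions via `transformDownWithSubqueries`), rather than building a fresh map per `WithCTE` node. Below is a minimal, self-contained sketch of that renumbering scheme; the object name `NormalizeIdsSketch` and the sample ids are hypothetical illustrations, not part of the patch.

```scala
import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable

// Hypothetical standalone sketch: renumber arbitrary original ids into a
// dense, deterministic sequence, mirroring how the rule remaps CTE def ids.
object NormalizeIdsSketch {
  def main(args: Array[String]): Unit = {
    val curId = new AtomicLong()
    val idMap = mutable.Map.empty[Long, Long]

    // First sighting of an id allocates the next counter value; later
    // sightings (e.g. the same CTE referenced from a subquery) reuse it,
    // because getOrElseUpdate only evaluates getAndIncrement() on a miss.
    def normalize(id: Long): Long =
      idMap.getOrElseUpdate(id, curId.getAndIncrement())

    val seen = Seq(107L, 42L, 107L, 42L).map(normalize)
    assert(seen == Seq(0L, 1L, 0L, 1L))
    println(seen.mkString(", ")) // 0, 1, 0, 1
  }
}
```

Because the map outlives any single `WithCTE` node, a definition first encountered in the main plan and a reference to it inside an EXISTS subquery normalize to the same id, which is what lets the cached plan in the new test match on lookup.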