From 3a5d16b2cdc41a5122a254d5ad2dadde5c42573b Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Thu, 18 Dec 2025 17:46:41 +0800
Subject: [PATCH 01/13] [SPARK-46741][SQL] Cache Table with CTE should work
 when CTE in subquery

---
 .../catalyst/normalizer/NormalizeCTEIds.scala |  4 +--
 .../apache/spark/sql/CachedTableSuite.scala   | 25 +++++++++++++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index 1b1b526e78140..13000c61a89a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
 object NormalizeCTEIds extends Rule[LogicalPlan]{
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
-    plan transformDown {
-
+    println(plan)
+    plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
         ctas.copy(plan = apply(plan))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 880d8d72c73e7..8b917b537e8da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -2633,6 +2633,31 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
     }
     assert(inMemoryTableScan.size == 1)
     checkAnswer(df, Row(5) :: Nil)
+
+    sql(
+      """
+        |CACHE TABLE cache_subquery_cte_table
+        |SELECT * FROM (
+        |  WITH v AS (
+        |    SELECT c1 * c2 c3 from t1
+        |  )
+        |  SELECT SUM(c3) FROM v
+        |)
+        |EXCEPT
+        |SELECT s FROM t2
+        |""".stripMargin)
+
+    val cteInSubquery = sql(
+      """
+        |SELECT * FROM cache_subquery_cte_table
+        |""".stripMargin)
+
+    cteInSubquery.explain(true)
+    val subqueryInMemoryTableScan = collect(df.queryExecution.executedPlan) {
+      case i: InMemoryTableScanExec => i
+    }
+    assert(subqueryInMemoryTableScan.size == 1)
+    checkAnswer(cteInSubquery, Row(5) :: Nil)
   }
 }

From 3c5db3f37dd701f26a9e9664ec903cfa15eff22d Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Thu, 18 Dec 2025 18:04:45 +0800
Subject: [PATCH 02/13] update

---
 .../catalyst/normalizer/NormalizeCTEIds.scala |  2 +-
 .../apache/spark/sql/CachedTableSuite.scala   | 25 -------------------
 2 files changed, 1 insertion(+), 26 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index 13000c61a89a5..e609f1b514d4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 object NormalizeCTEIds extends Rule[LogicalPlan]{
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
-    println(plan)
+
     plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
         ctas.copy(plan = apply(plan))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 8b917b537e8da..880d8d72c73e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -2633,31 +2633,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
     }
     assert(inMemoryTableScan.size == 1)
     checkAnswer(df, Row(5) :: Nil)
-
-    sql(
-      """
-        |CACHE TABLE cache_subquery_cte_table
-        |SELECT * FROM (
-        |  WITH v AS (
-        |    SELECT c1 * c2 c3 from t1
-        |  )
-        |  SELECT SUM(c3) FROM v
-        |)
-        |EXCEPT
-        |SELECT s FROM t2
-        |""".stripMargin)
-
-    val cteInSubquery = sql(
-      """
-        |SELECT * FROM cache_subquery_cte_table
-        |""".stripMargin)
-
-    cteInSubquery.explain(true)
-    val subqueryInMemoryTableScan = collect(df.queryExecution.executedPlan) {
-      case i: InMemoryTableScanExec => i
-    }
-    assert(subqueryInMemoryTableScan.size == 1)
-    checkAnswer(cteInSubquery, Row(5) :: Nil)
   }
 }

From 24870ed640d357ec3a67eee38b8965dc584f756c Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Thu, 18 Dec 2025 18:19:45 +0800
Subject: [PATCH 03/13] update

---
 .../apache/spark/sql/CachedTableSuite.scala | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 880d8d72c73e7..edfb6f79c8ff9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -2633,6 +2633,34 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
     }
     assert(inMemoryTableScan.size == 1)
     checkAnswer(df, Row(5) :: Nil)
+
+    sql(
+      """
+        |CACHE TABLE cache_subquery_cte_table
+        |WITH v AS (
+        |  SELECT c1 * c2 c3 from t1
+        |)
+        |SELECT *
+        |FROM v
+        |WHERE EXISTS (
+        |  WITH cte AS (SELECT 1 AS id)
+        |  SELECT 1
+        |  FROM cte
+        |  WHERE cte.id = v.c3
+        |)
+        |""".stripMargin)
+
+    val cteInSubquery = sql(
+      """
+        |SELECT * FROM cache_subquery_cte_table
+        |""".stripMargin)
+
+    cteInSubquery.explain(true)
+    val subqueryInMemoryTableScan = collect(cteInSubquery.queryExecution.executedPlan) {
+      case i: InMemoryTableScanExec => i
+    }
+    assert(subqueryInMemoryTableScan.size == 1)
+    checkAnswer(cteInSubquery, Row(1) :: Nil)
   }
 }

From 2b093616d94744a6819af747ed76235a6b192d63 Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Thu, 18 Dec 2025 18:20:05 +0800
Subject: [PATCH 04/13] Update NormalizeCTEIds.scala

---
 .../apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index e609f1b514d4f..f221625cfc14f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
 object NormalizeCTEIds extends Rule[LogicalPlan]{
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
-
     plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
         ctas.copy(plan = apply(plan))

From 11ba02f546e600b8c5a85f9ffedb17ec22fdf446 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 18 Dec 2025 18:34:51 +0800
Subject: [PATCH 05/13] Update
 sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

---
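(Note placed between the "---" separator and the diffstat, the slot git am
ignores: Dataset.explain(true) prints the extended plan, i.e. the parsed,
analyzed, and optimized logical plans plus the physical plan, to stdout. That
is handy while debugging locally but is noise in a committed test, which is
presumably why this patch drops the leftover call.)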
 .../src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index edfb6f79c8ff9..0d807aeae4d7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -2655,7 +2655,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
         |SELECT * FROM cache_subquery_cte_table
         |""".stripMargin)
 
-    cteInSubquery.explain(true)
     val subqueryInMemoryTableScan = collect(cteInSubquery.queryExecution.executedPlan) {
       case i: InMemoryTableScanExec => i
     }

From 4b7e1d8e1d5977a6109d5af36ff30ee056014d36 Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Fri, 19 Dec 2025 19:25:36 +0800
Subject: [PATCH 06/13] Update NormalizeCTEIds.scala

---
 .../spark/sql/catalyst/normalizer/NormalizeCTEIds.scala | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index f221625cfc14f..84ddfce27e57f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -20,14 +20,15 @@ package org.apache.spark.sql.catalyst.normalizer
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 
-object NormalizeCTEIds extends Rule[LogicalPlan]{
+object NormalizeCTEIds extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    val curId = new java.util.concurrent.atomic.AtomicLong()
+
     plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
         ctas.copy(plan = apply(plan))
 
       case withCTE @ WithCTE(plan, cteDefs) =>
+        val curId = new java.util.concurrent.atomic.AtomicLong()
         val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap
         val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
         val newCteDefs = cteDefs.map { cteDef =>

From c09421a2c4b25a5140f517f1b110a3167ea7f296 Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Mon, 22 Dec 2025 09:58:53 +0800
Subject: [PATCH 07/13] Revert "Update NormalizeCTEIds.scala"

This reverts commit 4b7e1d8e1d5977a6109d5af36ff30ee056014d36.
---
 .../spark/sql/catalyst/normalizer/NormalizeCTEIds.scala | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index 84ddfce27e57f..f221625cfc14f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -20,15 +20,14 @@ package org.apache.spark.sql.catalyst.normalizer
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 
-object NormalizeCTEIds extends Rule[LogicalPlan] {
+object NormalizeCTEIds extends Rule[LogicalPlan]{
   override def apply(plan: LogicalPlan): LogicalPlan = {
-
+    val curId = new java.util.concurrent.atomic.AtomicLong()
     plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
         ctas.copy(plan = apply(plan))
 
       case withCTE @ WithCTE(plan, cteDefs) =>
-        val curId = new java.util.concurrent.atomic.AtomicLong()
         val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap
         val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
         val newCteDefs = cteDefs.map { cteDef =>

From e65b59c52a653a46672b494df5843eb60bb175ae Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Mon, 22 Dec 2025 09:59:12 +0800
Subject: [PATCH 08/13] Update NormalizeCTEIds.scala

---
 .../apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index f221625cfc14f..5eb1ee2ecef03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.normalizer
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 
-object NormalizeCTEIds extends Rule[LogicalPlan]{
+object NormalizeCTEIds extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
     plan transformDownWithSubqueries {

From 960d8649483556e7e6971698402234ed4eb25a9e Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Mon, 22 Dec 2025 15:35:44 +0800
Subject: [PATCH 09/13] update

---
 .../catalyst/normalizer/NormalizeCTEIds.scala | 37 ++++++++++++++-----
 1 file changed, 27 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index 5eb1ee2ecef03..954a16ca64071 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -17,37 +17,54 @@
 
 package org.apache.spark.sql.catalyst.normalizer
 
+import java.util.HashMap
+
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 
 object NormalizeCTEIds extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
+    val defIdToNewId = new HashMap[Long, Long]()
+
+    plan transformDownWithSubqueries {
+      case withCTE: WithCTE =>
+        withCTE.cteDefs.foreach { cteDef =>
+          if (!defIdToNewId.containsKey(cteDef.id)) {
+            defIdToNewId.put(cteDef.id, curId.getAndIncrement())
+          }
+        }
+        withCTE
+    }
+
+    applyInternal(plan, defIdToNewId)
+  }
+
+  private def applyInternal(plan: LogicalPlan, defIdToNewId: HashMap[Long, Long]): LogicalPlan = {
     plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
-        ctas.copy(plan = apply(plan))
+        ctas.copy(plan = applyInternal(plan, defIdToNewId))
 
       case withCTE @ WithCTE(plan, cteDefs) =>
-        val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap
         val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
         val newCteDefs = cteDefs.map { cteDef =>
           val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId)
-          cteDef.copy(child = normalizedCteDef, id = defIdToNewId(cteDef.id))
+          cteDef.copy(child = normalizedCteDef, id = defIdToNewId.get(cteDef.id))
         }
         withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
     }
   }
 
-  def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: Map[Long, Long]): LogicalPlan = {
+  private def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: HashMap[Long, Long]): LogicalPlan = {
     plan.transformDownWithSubqueries {
       // For nested WithCTE, if defIndex didn't contain the cteId,
       // means it's not current WithCTE's ref.
-      case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
-        ref.copy(cteId = defIdToNewId(ref.cteId))
-      case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) =>
-        unionLoop.copy(id = defIdToNewId(unionLoop.id))
-      case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) =>
-        unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId))
+      case ref: CTERelationRef if defIdToNewId.containsKey(ref.cteId) =>
+        ref.copy(cteId = defIdToNewId.get(ref.cteId))
+      case unionLoop: UnionLoop if defIdToNewId.containsKey(unionLoop.id) =>
+        unionLoop.copy(id = defIdToNewId.get(unionLoop.id))
+      case unionLoopRef: UnionLoopRef if defIdToNewId.containsKey(unionLoopRef.loopId) =>
+        unionLoopRef.copy(loopId = defIdToNewId.get(unionLoopRef.loopId))
     }
   }
 }

From 942a98f4d75c601c607eb56bb436cd76d6a0501f Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Mon, 22 Dec 2025 20:41:56 +0800
Subject: [PATCH 10/13] Update NormalizeCTEIds.scala

---
 .../catalyst/normalizer/NormalizeCTEIds.scala | 26 ++++++++-----------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index 954a16ca64071..86718d2550ce8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.normalizer
 
 import java.util.HashMap
+import java.util.concurrent.atomic.AtomicLong
 
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -26,31 +27,26 @@ object NormalizeCTEIds extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
     val defIdToNewId = new HashMap[Long, Long]()
-
-    plan transformDownWithSubqueries {
-      case withCTE: WithCTE =>
-        withCTE.cteDefs.foreach { cteDef =>
-          if (!defIdToNewId.containsKey(cteDef.id)) {
-            defIdToNewId.put(cteDef.id, curId.getAndIncrement())
-          }
-        }
-        withCTE
-    }
-
-    applyInternal(plan, defIdToNewId)
+    applyInternal(plan, curId, defIdToNewId)
   }
 
-  private def applyInternal(plan: LogicalPlan, defIdToNewId: HashMap[Long, Long]): LogicalPlan = {
+  private def applyInternal(
+      plan: LogicalPlan,
+      curId: AtomicLong,
+      defIdToNewId: HashMap[Long, Long]): LogicalPlan = {
     plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
-        ctas.copy(plan = applyInternal(plan, defIdToNewId))
+        ctas.copy(plan = applyInternal(plan, curId, defIdToNewId))
 
       case withCTE @ WithCTE(plan, cteDefs) =>
-        val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
         val newCteDefs = cteDefs.map { cteDef =>
+          if (!defIdToNewId.containsKey(cteDef.id)) {
+            defIdToNewId.put(cteDef.id, curId.getAndIncrement())
+          }
           val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId)
           cteDef.copy(child = normalizedCteDef, id = defIdToNewId.get(cteDef.id))
         }
+        val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
         withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
     }
   }

From de652617f69172b1f6e482074b84b0273b7182a1 Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Mon, 22 Dec 2025 20:57:04 +0800
Subject: [PATCH 11/13] Update NormalizeCTEIds.scala

---
 .../catalyst/normalizer/NormalizeCTEIds.scala | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index 86718d2550ce8..a9e364b536820 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -26,27 +26,27 @@ import org.apache.spark.sql.catalyst.rules.Rule
 object NormalizeCTEIds extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
-    val defIdToNewId = new HashMap[Long, Long]()
-    applyInternal(plan, curId, defIdToNewId)
+    val cteIdToNewId = new HashMap[Long, Long]()
+    applyInternal(plan, curId, cteIdToNewId)
   }
 
   private def applyInternal(
       plan: LogicalPlan,
       curId: AtomicLong,
-      defIdToNewId: HashMap[Long, Long]): LogicalPlan = {
+      cteIdToNewId: HashMap[Long, Long]): LogicalPlan = {
     plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
-        ctas.copy(plan = applyInternal(plan, curId, defIdToNewId))
+        ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId))
 
       case withCTE @ WithCTE(plan, cteDefs) =>
         val newCteDefs = cteDefs.map { cteDef =>
-          if (!defIdToNewId.containsKey(cteDef.id)) {
-            defIdToNewId.put(cteDef.id, curId.getAndIncrement())
+          if (!cteIdToNewId.containsKey(cteDef.id)) {
+            cteIdToNewId.put(cteDef.id, curId.getAndIncrement())
           }
-          val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId)
-          cteDef.copy(child = normalizedCteDef, id = defIdToNewId.get(cteDef.id))
+          val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
+          cteDef.copy(child = normalizedCteDef, id = cteIdToNewId.get(cteDef.id))
         }
-        val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
+        val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId)
         withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
     }
   }

From ae16b0ff570752c8b5cbbd503a0facf015e683d5 Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Mon, 22 Dec 2025 21:03:32 +0800
Subject: [PATCH 12/13] Update NormalizeCTEIds.scala

---
 .../catalyst/normalizer/NormalizeCTEIds.scala | 29 ++++++++++---------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index a9e364b536820..3d67331665970 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -17,50 +17,53 @@
 
 package org.apache.spark.sql.catalyst.normalizer
 
-import java.util.HashMap
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 
 object NormalizeCTEIds extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
-    val cteIdToNewId = new HashMap[Long, Long]()
+    val cteIdToNewId = mutable.Map.empty[Long, Long]
     applyInternal(plan, curId, cteIdToNewId)
   }
 
   private def applyInternal(
       plan: LogicalPlan,
      curId: AtomicLong,
-      cteIdToNewId: HashMap[Long, Long]): LogicalPlan = {
+      cteIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
     plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
         ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId))
 
       case withCTE @ WithCTE(plan, cteDefs) =>
         val newCteDefs = cteDefs.map { cteDef =>
-          if (!cteIdToNewId.containsKey(cteDef.id)) {
-            cteIdToNewId.put(cteDef.id, curId.getAndIncrement())
+          if (!cteIdToNewId.contains(cteDef.id)) {
+            cteIdToNewId(cteDef.id) = curId.getAndIncrement()
           }
           val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
-          cteDef.copy(child = normalizedCteDef, id = cteIdToNewId.get(cteDef.id))
+          cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id))
         }
         val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId)
         withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
     }
   }
 
-  private def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: HashMap[Long, Long]): LogicalPlan = {
+  private def canonicalizeCTE(
+      plan: LogicalPlan,
+      defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
     plan.transformDownWithSubqueries {
       // For nested WithCTE, if defIndex didn't contain the cteId,
       // means it's not current WithCTE's ref.
-      case ref: CTERelationRef if defIdToNewId.containsKey(ref.cteId) =>
-        ref.copy(cteId = defIdToNewId.get(ref.cteId))
-      case unionLoop: UnionLoop if defIdToNewId.containsKey(unionLoop.id) =>
-        unionLoop.copy(id = defIdToNewId.get(unionLoop.id))
-      case unionLoopRef: UnionLoopRef if defIdToNewId.containsKey(unionLoopRef.loopId) =>
-        unionLoopRef.copy(loopId = defIdToNewId.get(unionLoopRef.loopId))
+      case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
+        ref.copy(cteId = defIdToNewId(ref.cteId))
+      case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) =>
+        unionLoop.copy(id = defIdToNewId(unionLoop.id))
+      case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) =>
+        unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId))
     }
   }
 }

From 4bad19b95c9546f8ad5d9c5ebaf47d263acfd1f6 Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Tue, 23 Dec 2025 10:25:52 +0800
Subject: [PATCH 13/13] follow comment

---
 .../spark/sql/catalyst/normalizer/NormalizeCTEIds.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index 3d67331665970..6c0bca0e1104f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -41,9 +41,7 @@ object NormalizeCTEIds extends Rule[LogicalPlan] {
 
       case withCTE @ WithCTE(plan, cteDefs) =>
         val newCteDefs = cteDefs.map { cteDef =>
-          if (!cteIdToNewId.contains(cteDef.id)) {
-            cteIdToNewId(cteDef.id) = curId.getAndIncrement()
-          }
+          cteIdToNewId.getOrElseUpdate(cteDef.id, curId.getAndIncrement())
          val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
           cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id))
         }
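
A standalone sketch of the id-normalization idea the series converges on
(illustrative only: the normalize helper and the sample id sequences below are
invented for this note and are not part of the PR). transformDownWithSubqueries
visits WithCTE nodes inside subqueries as well, and a shared mutable map plus
an AtomicLong renumbers every CTE definition id sequentially, so two plans that
differ only in the ids minted at analysis time canonicalize to the same plan
and the cache lookup can match:

    import java.util.concurrent.atomic.AtomicLong
    import scala.collection.mutable

    // Assign each distinct raw id the next sequential id, reusing the
    // mapping when a raw id repeats (the getOrElseUpdate pattern that
    // PATCH 13 lands on).
    def normalize(rawIds: Seq[Long]): Seq[Long] = {
      val curId = new AtomicLong()
      val idMap = mutable.Map.empty[Long, Long]
      rawIds.map(id => idMap.getOrElseUpdate(id, curId.getAndIncrement()))
    }

    // Ids minted when the table was cached vs. ids minted at lookup time
    // normalize to the same sequence, so the canonical plans compare equal.
    assert(normalize(Seq(7L, 9L, 7L)) == Seq(0L, 1L, 0L))
    assert(normalize(Seq(12L, 15L, 12L)) == Seq(0L, 1L, 0L))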