Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
c151b60
[SPARK-46741][SQL] Cache Table with CET won't work
AngersZhuuuu Jan 17, 2024
b9711ab
Update basicLogicalOperators.scala
AngersZhuuuu Jan 17, 2024
38c68d8
Update basicLogicalOperators.scala
AngersZhuuuu Jan 19, 2024
d587e6d
update
AngersZhuuuu Jan 23, 2024
8b2a25d
Merge branch 'master' into SPARK-46741
AngersZhuuuu Feb 2, 2024
bb15d32
Follow comment to add normalize rule
AngersZhuuuu May 8, 2024
fa17613
Merge branch 'master' into SPARK-46741
AngersZhuuuu May 8, 2024
a6e2657
Merge branch 'master' into SPARK-46741
AngersZhuuuu May 8, 2024
c9017ef
Update QueryExecution.scala
AngersZhuuuu May 8, 2024
4aa6aee
Update cache.sql.out
AngersZhuuuu May 8, 2024
99bf379
Merge branch 'master' into SPARK-46741
AngersZhuuuu May 14, 2024
054231f
follow comment
AngersZhuuuu May 14, 2024
92a213d
Update cache.sql.out
AngersZhuuuu May 14, 2024
123a986
Revert "follow comment"
AngersZhuuuu May 14, 2024
611b1b4
Merge branch 'master' into SPARK-46741
AngersZhuuuu Dec 5, 2025
5448b73
Update BaseSessionStateBuilder.scala
AngersZhuuuu Dec 5, 2025
99ee735
update
AngersZhuuuu Dec 5, 2025
6a92b32
Update cte-recursion.sql.out
AngersZhuuuu Dec 5, 2025
5770ff3
Update WithCTENormalized.scala
AngersZhuuuu Dec 12, 2025
7243d1f
Update WithCTENormalized.scala
AngersZhuuuu Dec 12, 2025
40a4332
Update WithCTENormalized.scala
AngersZhuuuu Dec 14, 2025
95bcf5d
Update WithCTENormalized.scala
AngersZhuuuu Dec 16, 2025
2b170fd
Merge branch 'master' into SPARK-46741
AngersZhuuuu Dec 16, 2025
1adab54
Update
AngersZhuuuu Dec 17, 2025
f99feb1
Update SQLQuerySuite.scala
AngersZhuuuu Dec 17, 2025
de0a94d
follow comment
AngersZhuuuu Dec 18, 2025
d10fd2b
update
AngersZhuuuu Dec 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.normalizer

import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
import org.apache.spark.sql.catalyst.rules.Rule

object NormalizeCTEIds extends Rule[LogicalPlan]{
override def apply(plan: LogicalPlan): LogicalPlan = {
val curId = new java.util.concurrent.atomic.AtomicLong()
plan transformDown {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this rule normalize ids of WithCTE nodes in subquery expressions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah missed it, we should use transformDownWithSubqueries here as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AngersZhuuuu can you create a followup?


case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
ctas.copy(plan = apply(plan))

case withCTE @ WithCTE(plan, cteDefs) =>
val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap
Copy link
Contributor

@peter-toth peter-toth Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we use AtomicLong.getAndIncrement() here? Could a simple var work here?
Anyways, this is just a nit.

Copy link
Contributor

@peter-toth peter-toth Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, you normalize WithCTE nodes one by one, but the nodes can be nested: once you have normalized one of them, can the normalized ids conflict with the not-yet-normalized ids of another node?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't come up with an exact example where processing WithCTE nodes one by one is an issue, but I wonder if traversing the whole plan and maintaining a global replacement map would be a safer and more efficient solution.

val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
val newCteDefs = cteDefs.map { cteDef =>
val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId)
cteDef.copy(child = normalizedCteDef, id = defIdToNewId(cteDef.id))
}
withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
}
}

/**
 * Rewrites CTE-related ids found in `plan` according to `defIdToNewId`.
 *
 * The traversal is top-down and uses `transformDownWithSubqueries`. Three node kinds are
 * considered: `CTERelationRef` (its `cteId`), `UnionLoop` (its `id`) and `UnionLoopRef`
 * (its `loopId`). A node whose id is not a key of `defIdToNewId` is returned unchanged;
 * with nested `WithCTE` nodes this means the id belongs to a different `WithCTE` than the
 * one whose mapping was passed in.
 *
 * @param plan the plan whose ids should be rewritten
 * @param defIdToNewId mapping from an original CTE definition id to its replacement id
 * @return the plan with every mapped id replaced
 */
def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: Map[Long, Long]): LogicalPlan = {
  plan.transformDownWithSubqueries {
    // Each case performs a single lookup; when the id is unmapped the very same node
    // instance is handed back, which leaves that node untouched.
    case ref: CTERelationRef =>
      defIdToNewId.get(ref.cteId).fold(ref)(newId => ref.copy(cteId = newId))
    case loop: UnionLoop =>
      defIdToNewId.get(loop.id).fold(loop)(newId => loop.copy(id = newId))
    case loopRef: UnionLoopRef =>
      defIdToNewId.get(loopRef.loopId).fold(loopRef)(newId => loopRef.copy(loopId = newId))
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1742,7 +1742,8 @@ case class CacheTableAsSelect(
isLazy: Boolean,
options: Map[String, String],
isAnalyzed: Boolean = false,
referredTempFunctions: Seq[String] = Seq.empty) extends AnalysisOnlyCommand {
referredTempFunctions: Seq[String] = Seq.empty)
extends AnalysisOnlyCommand with CTEInChildren {
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): CacheTableAsSelect = {
assert(!isAnalyzed)
Expand All @@ -1757,6 +1758,10 @@ case class CacheTableAsSelect(
// Collect the referred temporary functions from AnalysisContext
referredTempFunctions = ac.referredTempFunctionNames.toSeq)
}

/**
 * Returns a copy of this command whose `plan` is the current `plan` wrapped in a
 * `WithCTE` node that carries the given CTE definitions.
 *
 * @param cteDefs the CTE definitions to attach to the wrapped plan
 */
override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan = {
  val planWithCteDefs = WithCTE(plan, cteDefs)
  copy(plan = planWithCteDefs)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTr
import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExtractSemiStructuredFields}
import org.apache.spark.sql.catalyst.normalizer.NormalizeCTEIds
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
Expand Down Expand Up @@ -403,6 +404,7 @@ abstract class BaseSessionStateBuilder(
}

/**
 * Plan normalization rules for the session: `NormalizeCTEIds` comes first, followed by
 * the rules built from the session extensions.
 */
protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
  val extensionRules = extensions.buildPlanNormalizationRules(session)
  NormalizeCTEIds +: extensionRules
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2598,6 +2598,44 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
}
}

// Regression test for SPARK-46741: a CACHE TABLE statement whose query contains a CTE,
// and which also reads a temp view that itself contains a CTE, must be answered from the
// cache when the cached table is queried afterwards.
test("SPARK-46741: Cache Table with CTE should work") {
  // `CACHE TABLE <name> AS <query>` registers `<name>` as a cached temporary view, so it
  // is listed here as well: dropping it at the end keeps the view and its cache entry
  // from leaking into other tests of this suite.
  withTempView("t1", "t2", "cache_nested_cte_table") {
    sql(
      """
        |CREATE TEMPORARY VIEW t1
        |AS
        |SELECT * FROM VALUES (0, 0), (1, 1), (2, 2) AS t(c1, c2)
        |""".stripMargin)
    sql(
      """
        |CREATE TEMPORARY VIEW t2 AS
        |WITH v as (
        | SELECT c1 + c1 c3 FROM t1
        |)
        |SELECT SUM(c3) s FROM v
        |""".stripMargin)
    sql(
      """
        |CACHE TABLE cache_nested_cte_table
        |WITH
        |v AS (
        | SELECT c1 * c2 c3 from t1
        |)
        |SELECT SUM(c3) FROM v
        |EXCEPT
        |SELECT s FROM t2
        |""".stripMargin)

    val df = sql("SELECT * FROM cache_nested_cte_table")

    // The executed plan must contain exactly one in-memory scan, i.e. the cached data
    // was actually picked up for this query.
    val inMemoryTableScan = collect(df.queryExecution.executedPlan) {
      case i: InMemoryTableScanExec => i
    }
    assert(inMemoryTableScan.size == 1)
    // SUM(c1 * c2) over t1 is 0 + 1 + 4 = 5, while t2 yields SUM(c1 + c1) = 6,
    // so the EXCEPT keeps the single row 5.
    checkAnswer(df, Row(5) :: Nil)
  }
}

private def cacheManager = spark.sharedState.cacheManager

private def pinTable(
Expand Down