@@ -17,29 +17,42 @@

 package org.apache.spark.sql.catalyst.normalizer
 
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule

-object NormalizeCTEIds extends Rule[LogicalPlan]{
+object NormalizeCTEIds extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
Contributor:

Shouldn't we have unique CTE ids per query? Then we need a new AtomicLong instance per apply invocation.

Contributor Author:

I know; directly changing this to transformDownWithSubqueries makes the UT for SPARK-51109 fail.

For this query:

  test("SPARK-51109: CTE in subquery expression as grouping column") {
    withTable("t") {
      Seq(1 -> 1).toDF("c1", "c2").write.saveAsTable("t")
      withView("v") {
        sql(
          """
            |CREATE VIEW v AS
            |WITH r AS (SELECT c1 + c2 AS c FROM t)
            |SELECT * FROM r
            |""".stripMargin)
        checkAnswer(
          sql("SELECT (SELECT max(c) FROM v WHERE c > id) FROM range(1) GROUP BY 1"),
          Row(2)
        )
      }
    }
  }

The plan will be normalized from:

Aggregate [scalar-subquery#15 [id#16L]], [scalar-subquery#15 [id#16L] AS scalarsubquery(id)#21]
:  :- Aggregate [max(c#18) AS max(c)#20]
:  :  +- Filter (cast(c#18 as bigint) > outer(id#16L))
:  :     +- SubqueryAlias spark_catalog.default.v
:  :        +- View (`spark_catalog`.`default`.`v`, [c#18])
:  :           +- Project [cast(c#17 as int) AS c#18]
:  :              +- WithCTE
:  :                 :- CTERelationDef 1, false
:  :                 :  +- SubqueryAlias r
:  :                 :     +- Project [(c1#12 + c2#13) AS c#17]
:  :                 :        +- SubqueryAlias spark_catalog.default.t
:  :                 :           +- Relation spark_catalog.default.t[c1#12,c2#13] parquet
:  :                 +- Project [c#17]
:  :                    +- SubqueryAlias r
:  :                       +- CTERelationRef 1, true, [c#17], false, false
:  +- Aggregate [max(c#26) AS max(c)#27]
:     +- Filter (cast(c#26 as bigint) > outer(id#16L))
:        +- SubqueryAlias spark_catalog.default.v
:           +- View (`spark_catalog`.`default`.`v`, [c#26])
:              +- Project [cast(c#25 as int) AS c#26]
:                 +- WithCTE
:                    :- CTERelationDef 1, false
:                    :  +- SubqueryAlias r
:                    :     +- Project [(c1#22 + c2#23) AS c#24]
:                    :        +- SubqueryAlias spark_catalog.default.t
:                    :           +- Relation spark_catalog.default.t[c1#22,c2#23] parquet
:                    +- Project [c#25]
:                       +- SubqueryAlias r
:                          +- CTERelationRef 1, true, [c#25], false, false
+- Range (0, 1, step=1)

to

Aggregate [scalar-subquery#15 [id#16L]], [scalar-subquery#15 [id#16L] AS scalarsubquery(id)#21]
:  :- Aggregate [max(c#18) AS max(c)#20]
:  :  +- Filter (cast(c#18 as bigint) > outer(id#16L))
:  :     +- SubqueryAlias spark_catalog.default.v
:  :        +- View (`spark_catalog`.`default`.`v`, [c#18])
:  :           +- Project [cast(c#17 as int) AS c#18]
:  :              +- WithCTE
:  :                 :- CTERelationDef 0, false
:  :                 :  +- SubqueryAlias r
:  :                 :     +- Project [(c1#12 + c2#13) AS c#17]
:  :                 :        +- SubqueryAlias spark_catalog.default.t
:  :                 :           +- Relation spark_catalog.default.t[c1#12,c2#13] parquet
:  :                 +- Project [c#17]
:  :                    +- SubqueryAlias r
:  :                       +- CTERelationRef 0, true, [c#17], false, false
:  +- Aggregate [max(c#26) AS max(c)#27]
:     +- Filter (cast(c#26 as bigint) > outer(id#16L))
:        +- SubqueryAlias spark_catalog.default.v
:           +- View (`spark_catalog`.`default`.`v`, [c#26])
:              +- Project [cast(c#25 as int) AS c#26]
:                 +- WithCTE
:                    :- CTERelationDef 1, false
:                    :  +- SubqueryAlias r
:                    :     +- Project [(c1#22 + c2#23) AS c#24]
:                    :        +- SubqueryAlias spark_catalog.default.t
:                    :           +- Relation spark_catalog.default.t[c1#22,c2#23] parquet
:                    +- Project [c#25]
:                       +- SubqueryAlias r
:                          +- CTERelationRef 1, true, [c#25], false, false
+- Range (0, 1, step=1)

Within the same plan, the normalized CTE id changed, causing the following to be thrown:

[info]  is not a valid aggregate expression: [SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION] The correlated scalar subquery '"scalarsubquery(id)"' is neither present in GROUP BY, nor in an aggregate function.
[info] Add it to GROUP BY using ordinal position or wrap it in `first()` (or `first_value`) if you don't care which value you get. SQLSTATE: 0A000; line 1 pos 7
[info] Previous schema:scalarsubquery(id)#21
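
For intuition, a toy standalone sketch (the perWithCTEMap helper is hypothetical, not PR code) of why a per-WithCTE map diverges on duplicated subtrees, which is exactly the 1 -> 0 versus 1 -> 1 split above:

  import java.util.concurrent.atomic.AtomicLong

  val curId = new AtomicLong()

  // Each WithCTE builds its own fresh map, so the same def id 1 is
  // renumbered to 0 in the first subquery copy and to 1 in the second.
  def perWithCTEMap(defIds: Seq[Long]): Map[Long, Long] =
    defIds.map(id => id -> curId.getAndIncrement()).toMap

  perWithCTEMap(Seq(1L)) // Map(1 -> 0): first copy of the subquery
  perWithCTEMap(Seq(1L)) // Map(1 -> 1): second copy of the subquery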

I am still trying to figure out how to fix this problem.

Contributor:

This means we should handle the case when CTE def IDs can be duplicated. In such cases, we should not generate new IDs blindly.

  val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap

We need to fix this line. The id map should be per apply invocation, not per WithCTE.
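
A minimal standalone sketch of that suggestion (the normalizedId helper is hypothetical, not the PR's code): the counter and the map are created once per apply invocation, and getOrElseUpdate keeps a def id that reappears in a duplicated subtree mapped to the same normalized id.

  import java.util.concurrent.atomic.AtomicLong
  import scala.collection.mutable

  val curId = new AtomicLong()
  val cteIdToNewId = mutable.Map.empty[Long, Long]

  // First sight assigns a fresh id; later sights reuse it.
  def normalizedId(defId: Long): Long =
    cteIdToNewId.getOrElseUpdate(defId, curId.getAndIncrement())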

Contributor:

Or use a global map in one traversal: #53333 (comment)

Contributor Author:

ping @cloud-fan, how about the current version?

-    plan transformDown {
+    val cteIdToNewId = mutable.Map.empty[Long, Long]
+    applyInternal(plan, curId, cteIdToNewId)
+  }
+
+  private def applyInternal(
+      plan: LogicalPlan,
+      curId: AtomicLong,
+      cteIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
+    plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
-        ctas.copy(plan = apply(plan))
+        ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId))
 
       case withCTE @ WithCTE(plan, cteDefs) =>
-        val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap
-        val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
         val newCteDefs = cteDefs.map { cteDef =>
-          val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId)
-          cteDef.copy(child = normalizedCteDef, id = defIdToNewId(cteDef.id))
+          cteIdToNewId.getOrElseUpdate(cteDef.id, curId.getAndIncrement())
+          val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
+          cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id))
         }
+        val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId)
         withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
     }
   }

-  def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: Map[Long, Long]): LogicalPlan = {
+  private def canonicalizeCTE(
+      plan: LogicalPlan,
+      defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
     plan.transformDownWithSubqueries {
       // For nested WithCTE: if defIdToNewId doesn't contain the cteId,
       // the ref doesn't belong to the current WithCTE.
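
For reference, the added pieces assemble into roughly the following shape. This is a sketch reconstructed from the diff above, not the committed file: the body of canonicalizeCTE is truncated in the diff, so its match arm here is an assumption based on the imported CTERelationRef / UnionLoopRef node types.

object NormalizeCTEIds extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    // The counter and the old-id -> new-id map live for one apply()
    // invocation, so duplicated CTE defs across subqueries are
    // renumbered consistently.
    val curId = new java.util.concurrent.atomic.AtomicLong()
    val cteIdToNewId = mutable.Map.empty[Long, Long]
    applyInternal(plan, curId, cteIdToNewId)
  }

  private def applyInternal(
      plan: LogicalPlan,
      curId: AtomicLong,
      cteIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
    plan transformDownWithSubqueries {
      case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
        ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId))

      case withCTE @ WithCTE(plan, cteDefs) =>
        val newCteDefs = cteDefs.map { cteDef =>
          // Reuse the normalized id if this def id has been seen before.
          cteIdToNewId.getOrElseUpdate(cteDef.id, curId.getAndIncrement())
          val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
          cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id))
        }
        val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId)
        withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
    }
  }

  private def canonicalizeCTE(
      plan: LogicalPlan,
      defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
    plan.transformDownWithSubqueries {
      // A ref whose cteId is missing from defIdToNewId belongs to a nested
      // WithCTE, so it is left untouched by this invocation.
      case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
        ref.copy(cteId = defIdToNewId(ref.cteId))
      // UnionLoop / UnionLoopRef handling is elided in the truncated diff.
    }
  }
}
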
@@ -2633,6 +2633,33 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
       }
       assert(inMemoryTableScan.size == 1)
       checkAnswer(df, Row(5) :: Nil)
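+
+      // The cached query defines one CTE at the top level and a second CTE
+      // inside the EXISTS subquery; reading the cached table back must still
+      // hit the InMemoryRelation.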

+      sql(
+        """
+          |CACHE TABLE cache_subquery_cte_table
+          |WITH v AS (
+          |  SELECT c1 * c2 c3 from t1
+          |)
+          |SELECT *
+          |FROM v
+          |WHERE EXISTS (
+          |  WITH cte AS (SELECT 1 AS id)
+          |  SELECT 1
+          |  FROM cte
+          |  WHERE cte.id = v.c3
+          |)
+          |""".stripMargin)
+
+      val cteInSubquery = sql(
+        """
+          |SELECT * FROM cache_subquery_cte_table
+          |""".stripMargin)
+
+      val subqueryInMemoryTableScan = collect(cteInSubquery.queryExecution.executedPlan) {
+        case i: InMemoryTableScanExec => i
+      }
+      assert(subqueryInMemoryTableScan.size == 1)
+      checkAnswer(cteInSubquery, Row(1) :: Nil)
     }
   }
