diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4babd408403dd..bfeb5237daa38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -691,7 +691,8 @@ object CollapseProject extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, p2: Project) =>
-      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
+          hasOversizedRepeatedAliases(p1.projectList, p2.projectList)) {
         p1
       } else {
         p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
@@ -735,6 +736,28 @@ object CollapseProject extends Rule[LogicalPlan] {
     }.exists(!_.deterministic))
   }
 
+  private def hasOversizedRepeatedAliases(
+      upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
+    val aliases = collectAliases(lower)
+
+    // Count how many times each alias is used in the upper Project.
+    // If an alias is only used once, we can safely substitute it without increasing the overall
+    // tree size.
+    val referenceCounts = AttributeMap(
+      upper
+        .flatMap(_.collect { case a: Attribute => a })
+        .groupBy(identity)
+        .mapValues(_.size).toSeq
+    )
+
+    // Check for any aliases that are used more than once, and are larger than the configured
+    // maximum size.
+    aliases.exists { case (attribute, expression) =>
+      referenceCounts.getOrElse(attribute, 0) > 1 &&
+        expression.treeSize > SQLConf.get.maxRepeatedAliasSize
+    }
+  }
+
   private def buildCleanedProjectList(
       upper: Seq[NamedExpression],
       lower: Seq[NamedExpression]): Seq[NamedExpression] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index a816922f49aee..e29b7b9eae316 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
@@ -60,8 +61,13 @@
     plan match {
       case Project(fields, child) if fields.forall(_.deterministic) =>
         val (_, filters, other, aliases) = collectProjectsAndFilters(child)
-        val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
-        (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
+        if (hasOversizedRepeatedAliases(fields, aliases)) {
+          // Skip substitution if it could overly increase the overall tree size and risk OOMs.
+          (None, Nil, plan, Map.empty)
+        } else {
+          val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
+          (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
+        }
 
       case Filter(condition, child) if condition.deterministic =>
         val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
@@ -79,6 +85,26 @@ object PhysicalOperation extends PredicateHelper {
       case a @ Alias(child, _) => a.toAttribute -> child
     }.toMap
 
+  private def hasOversizedRepeatedAliases(fields: Seq[Expression],
+      aliases: Map[Attribute, Expression]): Boolean = {
+    // Count how many times each alias is used in the fields.
+    // If an alias is only used once, we can safely substitute it without increasing the overall
+    // tree size.
+    val referenceCounts = AttributeMap(
+      fields
+        .flatMap(_.collect { case a: Attribute => a })
+        .groupBy(identity)
+        .mapValues(_.size).toSeq
+    )
+
+    // Check for any aliases that are used more than once, and are larger than the configured
+    // maximum size.
+    aliases.exists { case (attribute, expression) =>
+      referenceCounts.getOrElse(attribute, 0) > 1 &&
+        expression.treeSize > SQLConf.get.maxRepeatedAliasSize
+    }
+  }
+
   private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
     expr.transform {
       case a @ Alias(ref: AttributeReference, name) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 72b1931e47cf2..f25d8a98b9185 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -89,6 +89,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   lazy val containsChild: Set[TreeNode[_]] = children.toSet
 
+  lazy val treeSize: Long = children.map(_.treeSize).sum + 1
+
   private lazy val _hashCode: Int = scala.util.hashing.MurmurHash3.productHash(this)
 
   override def hashCode(): Int = _hashCode
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 20f4080c98590..ef9ead1809532 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1717,6 +1717,15 @@ object SQLConf {
       "and java.sql.Date are used for the same purpose.")
     .booleanConf
     .createWithDefault(false)
+
+  val MAX_REPEATED_ALIAS_SIZE =
+    buildConf("spark.sql.maxRepeatedAliasSize")
+      .internal()
+      .doc("The maximum size of an alias expression that will be substituted multiple times " +
+        "(size defined by the number of nodes in the expression tree). " +
+        "Used by the CollapseProject optimizer and by PhysicalOperation.")
+      .intConf
+      .createWithDefault(100)
 }
 
 /**
@@ -2162,6 +2171,8 @@ class SQLConf extends Serializable with Logging {
   def setCommandRejectsSparkCoreConfs: Boolean =
     getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS)
 
+  def maxRepeatedAliasSize: Int = getConf(SQLConf.MAX_REPEATED_ALIAS_SIZE)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 42bcd13ee378d..4294f7e568ebb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -170,4 +170,22 @@ class CollapseProjectSuite extends PlanTest {
     val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze
     comparePlans(optimized, expected)
   }
+
+  test("ensure oversize aliases are not repeatedly substituted") {
+    var query: LogicalPlan = testRelation
+    for (a <- 1 to 100) {
+      query = query.select(('a + 'b).as('a), ('a - 'b).as('b))
+    }
+    val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
+    assert(projects.size >= 12)
+  }
+
+  test("ensure oversize aliases are still substituted once") {
+    var query: LogicalPlan = testRelation
+    for (a <- 1 to 20) {
+      query = query.select(('a + 'b).as('a), 'b)
+    }
+    val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
+    assert(projects.size === 1)
+  }
 }
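
The first test above is the pathological case this patch targets: every level redefines both columns in terms of both lower-level aliases, so inlining the aliases doubles the expression tree at each of the 100 Projects. The standalone sketch below (not part of the patch; Expr, Leaf, and BinOp are hypothetical stand-ins for Catalyst expression nodes) reproduces the arithmetic, with treeSize defined exactly as in the TreeNode change:

    // Models the alias blowup the patch guards against. Each level substitutes
    // a := a + b and b := a - b, so both lower expressions are referenced twice
    // and the tree doubles per level. treeSize mirrors TreeNode.treeSize:
    // the children's sizes summed, plus one for the node itself.
    sealed trait Expr { def treeSize: Long }
    case class Leaf(name: String) extends Expr { val treeSize = 1L }
    case class BinOp(op: String, left: Expr, right: Expr) extends Expr {
      val treeSize: Long = left.treeSize + right.treeSize + 1
    }

    object AliasBlowup extends App {
      var a: Expr = Leaf("a")
      var b: Expr = Leaf("b")
      for (level <- 1 to 30) {
        val (nextA, nextB) = (BinOp("+", a, b), BinOp("-", a, b))
        a = nextA
        b = nextB
      }
      println(a.treeSize) // 2^31 - 1 = 2147483647 nodes after just 30 levels
    }

After 30 levels the fully substituted expression already needs over two billion nodes; at the 100 levels used in the test it would be on the order of 2^101, which is why the unguarded rules could OOM the driver. The assertion projects.size >= 12 then checks that CollapseProject stopped merging Projects once the repeated aliases crossed the size threshold, rather than collapsing all 100 into one.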
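Because spark.sql.maxRepeatedAliasSize is registered via buildConf, it can be tuned like any other SQL conf even though it is marked internal(). A hypothetical override (assuming an active SparkSession named spark; the key and the default of 100 come from the SQLConf entry in this patch) would look like:

    // Raise the threshold so that moderately large repeated aliases
    // (up to 500 expression-tree nodes) are still inlined and collapsed.
    spark.conf.set("spark.sql.maxRepeatedAliasSize", "500")

Raising the value trades optimizer memory for more aggressive Project collapsing; the default of 100 nodes is meant to keep the worst case small.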