From 44d1de6ae0c8ab95e58dfd66039973cc8afe3aa8 Mon Sep 17 00:00:00 2001 From: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Date: Tue, 6 Apr 2021 18:16:40 -0700 Subject: [PATCH 1/5] dependent-join --- .../sql/catalyst/expressions/subquery.scala | 30 +- .../sql/catalyst/optimizer/subquery.scala | 386 +++++++++++++++++- .../plans/logical/basicLogicalOperators.scala | 9 + .../apache/spark/sql/internal/SQLConf.scala | 10 + .../DecorrelateInnerQuerySuite.scala | 282 +++++++++++++ 5 files changed, 678 insertions(+), 39 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 4e07e729ab3bb..dd24362aa80dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -140,14 +140,11 @@ object SubExprUtils extends PredicateHelper { * Given a logical plan, returns TRUE if it has an outer reference and false otherwise. */ def hasOuterReferences(plan: LogicalPlan): Boolean = { - plan.find { - case f: Filter => containsOuter(f.condition) - case other => false - }.isDefined + plan.find(_.expressions.exists(containsOuter)).isDefined } /** - * Given a list of expressions, returns the expressions which have outer references. Aggregate + * Given an expression, returns the expressions which have outer references. Aggregate * expressions are treated in a special way. If the children of aggregate expression contains an * outer reference, then the entire aggregate expression is marked as an outer reference. * Example (SQL): @@ -183,18 +180,16 @@ object SubExprUtils extends PredicateHelper { * }}} * The code below needs to change when we support the above cases. 
*/ - def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = { + def getOuterReferences(condition: Expression): Seq[Expression] = { val outerExpressions = ArrayBuffer.empty[Expression] - conditions foreach { expr => - expr transformDown { - case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => - val newExpr = stripOuterReference(a) - outerExpressions += newExpr - newExpr - case OuterReference(e) => - outerExpressions += e - e - } + condition transformDown { + case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => + val newExpr = stripOuterReference(a) + outerExpressions += newExpr + newExpr + case OuterReference(e) => + outerExpressions += e + e } outerExpressions.toSeq } @@ -204,8 +199,7 @@ object SubExprUtils extends PredicateHelper { * Filter operator can host outer references. */ def getOuterReferences(plan: LogicalPlan): Seq[Expression] = { - val conditions = plan.collect { case Filter(cond, _) => cond } - getOuterReferences(conditions) + plan.flatMap(_.expressions.flatMap(getOuterReferences)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index ef73e58645a89..aa06004821ac1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /* @@ -272,22 +273,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper val baseConditions = predicateMap.values.flatten.toSeq val (newPlan, newCond) = if 
(outer.nonEmpty) { val outputSet = outer.map(_.outputSet).reduce(_ ++ _) - val duplicates = transformed.outputSet.intersect(outputSet) - val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = transformed.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val aliasedProjection = Project(aliasedExpressions, transformed) - val aliasedConditions = baseConditions.map(_.transform { - case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute - }) - (aliasedProjection, aliasedConditions) - } else { - (transformed, baseConditions) - } + val (plan, deDuplicatedConditions) = + DecorrelateInnerQuery.deduplicate(transformed, baseConditions, outputSet) (plan, stripOuterReferences(deDuplicatedConditions)) } else { (transformed, stripOuterReferences(baseConditions)) @@ -308,9 +295,17 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper if (newCond.isEmpty) oldCond else newCond } + def rewrite(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + if (SQLConf.get.decorrelateInnerQueryEnabled) { + DecorrelateInnerQuery(sub, outer) + } else { + pullOutCorrelatedPredicates(sub, outer) + } + } + plan transformExpressions { case ScalarSubquery(sub, children, exprId) if children.nonEmpty => - val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + val (newPlan, newCond) = rewrite(sub, outerPlans) ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId) case Exists(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) @@ -379,7 +374,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe bindings: Map[ExprId, Expression]): Expression = { val rewrittenExpr = expr transform { case r: AttributeReference => - bindings.getOrElse(r.exprId, Literal.default(NullType)) + 
bindings.getOrElse(r.exprId, Literal.create(null, r.dataType)) } tryEvalExpr(rewrittenExpr) @@ -394,9 +389,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe // Also replace attribute refs (for example, for grouping columns) with NULL. val rewrittenExpr = expr transform { case a @ AggregateExpression(aggFunc, _, _, resultId, _) => - aggFunc.defaultResult.getOrElse(Literal.default(NullType)) + aggFunc.defaultResult.getOrElse(Literal.create(null, aggFunc.dataType)) - case _: AttributeReference => Literal.default(NullType) + case a: AttributeReference => Literal.create(null, a.dataType) } tryEvalExpr(rewrittenExpr) @@ -514,6 +509,56 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe // Name of generated column used in rewrite below val ALWAYS_TRUE_COLNAME = "alwaysTrue" + /** + * Build a mapping between domain attributes and corresponding outer query expressions + * using the join conditions. + */ + private def buildDomainAttrMap( + conditions: Seq[Expression], + domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = { + val outputSet = AttributeSet(domainAttrs) + conditions.collect { + // When we build the equality conditions, the left side is always the + // domain attributes used in the inner plan, and the right side is the + // attribute from outer plan. Note the right hand side is not necessarily + // an attribute, for example it can be a literal (if foldable) or a cast expression. + case EqualNullSafe(left: Attribute, right: Expression) if outputSet.contains(left) => + left -> right + }.toMap + } + + /** + * Rewrite domain join placeholder to actual inner joins. 
+ */ + private def rewriteDomainJoins( + outerPlan: LogicalPlan, + innerPlan: LogicalPlan, + conditions: Seq[Expression]): LogicalPlan = { + innerPlan transform { + case d @ DomainJoin(domainAttrs, child) => + val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs) + // We should only rewrite a domain join when all corresponding outer plan attributes + // can be found from the join condition. + if (domainAttrMap.size == domainAttrs.size) { + val groupingExprs = domainAttrs.map(domainAttrMap) + val aggregateExprs = groupingExprs.zip(domainAttrs).map { + // Rebuild the aliases. + case (inputAttr, outputAttr) => Alias(inputAttr, outputAttr.name)(outputAttr.exprId) + } + val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan) + child match { + // A special optimization for OneRowRelation. + // TODO: add a more general rule to optimize join with OneRowRelation. + case _: OneRowRelation => domain + case _ => Join(child, domain, Inner, None, JoinHint.NONE) + } + } else { + throw new UnsupportedOperationException( + s"Unable to rewrite domain join with conditions: $conditions\n$d") + } + } + } + /** * Construct a new child plan by left joining the given subqueries to a base plan. 
* This method returns the child plan and an attribute mapping @@ -525,7 +570,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = { val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]() val newChild = subqueries.foldLeft(child) { - case (currentChild, ScalarSubquery(query, conditions, _)) => + case (currentChild, ScalarSubquery(sub, conditions, _)) => + val query = rewriteDomainJoins(currentChild, sub, conditions) val origOutput = query.output.head val resultWithZeroTups = evalSubqueryOnZeroTups(query) @@ -678,3 +724,301 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe } } } + +/** + * Decorrelate the inner query by eliminating outer references and create domain joins. + * The implementation is based on the paper: Unnesting Arbitrary Queries by Thomas Neumann + * and Alfons Kemper. https://dl.gi.de/handle/20.500.12116/2418. + * (1) Recursively collects outer references from the inner query until it reaches a node + * that does not contain correlated value. + * (2) Inserts an optional [[DomainJoin]] node to indicate whether a domain (inner) join is + * needed between the outer query and the specific subtree of the inner query. + * (3) Returns a list of join conditions with the outer query and a mapping between outer + * references with references inside the inner query. The parent nodes need to preserve + * the references inside the join conditions and substitute all outer references using + * the mapping. + * + * E.g. decorrelate an inner query with equality predicates: + * + * Aggregate [] [min(c2)] Aggregate [c1] [min(c2), c1] + * +- Filter [outer(c3) = c1] => +- Relation [t] + * +- Relation [t] + * + * Join conditions: [c3 = c1] + * + * E.g. 
decorrelate an inner query with non-equality predicates: + * + * Aggregate [] [min(c2)] Aggregate [c3'] [min(c2), c3'] + * +- Filter [outer(c3) > c1] => +- Filter [c3' > c1] + * +- Relation [t] +- DomainJoin [c3'] + * +- Relation [t] + * + * Join conditions: [c3 <=> c3'] + */ +object DecorrelateInnerQuery extends PredicateHelper { + + /** + * Check if the given expression is an equality condition. + */ + private def isEquality(expression: Expression): Boolean = expression match { + case Equality(_, _) => true + case _ => false + } + + /** + * Collect outer references in an expressions that are in the output attributes of the outer plan. + */ + private def collectOuterReferences(expression: Expression): AttributeSet = { + AttributeSet(expression.collect { case o: OuterReference => o.toAttribute }) + } + + /** + * Collect outer references in a sequence of expressions that are in the output attributes + * of the outer plan. + */ + private def collectOuterReferences(expressions: Seq[Expression]): AttributeSet = { + AttributeSet.fromAttributeSets(expressions.map(collectOuterReferences)) + } + + /** + * Build a mapping between outer references with equivalent inner query attributes. + * E.g. [outer(a) = x, y = outer(b), outer(c) = z + 1] => {a -> x, b -> y} + */ + private def collectEquivalentOuterReferences( + expressions: Seq[Expression]): Map[Attribute, Attribute] = { + expressions.collect { + case Equality(o: OuterReference, a: Attribute) => (o.toAttribute, a.toAttribute) + case Equality(a: Attribute, o: OuterReference) => (o.toAttribute, a.toAttribute) + }.toMap + } + + /** + * Replace all outer references using the expressions in the given outer reference map. 
+ */ + private def replaceOuterReference[E <: Expression]( + expression: E, + outerReferenceMap: Map[Attribute, Attribute]): E = { + expression.transform { + case o: OuterReference => outerReferenceMap.getOrElse(o.toAttribute, o) + }.asInstanceOf[E] + } + + /** + * Replace all outer references in the given expressions using the expressions in the + * outer reference map. + */ + private def replaceOuterReferences[E <: Expression]( + expressions: Seq[E], + outerReferenceMap: Map[Attribute, Attribute]): Seq[E] = { + expressions.map(replaceOuterReference(_, outerReferenceMap)) + } + + /** + * Return all missing references of the attribute set from the required attributes + * in the join condition. + */ + private def missingReferences( + expressions: Seq[Expression], + joinCond: Seq[Expression]): AttributeSet = { + val outputSet = AttributeSet(expressions) + AttributeSet(joinCond.flatMap(_.references)) -- outputSet + } + + /** + * Deduplicate the inner and the outer query attributes and return an aliased + * subquery plan and join conditions if duplicates are found. Duplicated attributes + * can break the structural integrity when joining the inner and outer plan together. 
+ */ + def deduplicate( + innerPlan: LogicalPlan, + conditions: Seq[Expression], + outerOutputSet: AttributeSet): (LogicalPlan, Seq[Expression]) = { + val duplicates = innerPlan.outputSet.intersect(outerOutputSet) + if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = innerPlan.output.map { ref => + aliasMap.getOrElse(ref, ref) + } + val aliasedProjection = Project(aliasedExpressions, innerPlan) + val aliasedConditions = conditions.map(_.transform { + case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute + }) + (aliasedProjection, aliasedConditions) + } else { + (innerPlan, conditions) + } + } + + def apply( + innerPlan: LogicalPlan, + outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + apply(innerPlan, Seq(outerPlan)) + } + + def apply( + innerPlan: LogicalPlan, + outerPlans: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + val outputSet = AttributeSet(outerPlans.flatMap(_.outputSet)) + + // The return type of the recursion. + // The first parameter is a new logical plan with correlation eliminated. + // The second parameter is a list of join conditions with the outer query. + // The third parameter is a mapping between the outer references and equivalent + // expressions from the inner query that is used to replace outer references. + type ReturnType = (LogicalPlan, Seq[Expression], Map[Attribute, Attribute]) + + // Recursively decorrelate the input plan with a set of parent outer references and + // a boolean flag indicating whether the result of the plan will be aggregated. + def decorrelate( + plan: LogicalPlan, + parentOuterReferences: AttributeSet, + aggregated: Boolean = false): ReturnType = { + val isCorrelated = hasOuterReferences(plan) + if (!isCorrelated) { + // We have reached a plan without correlation to the outer plan. 
+ if (parentOuterReferences.isEmpty) { + // If there is no outer references from the parent nodes, it means all outer + // attributes can be substituted by attributes from the inner plan. So no + // domain join is needed. + (plan, Nil, Map.empty[Attribute, Attribute]) + } else { + // Build the domain join with the parent outer references. + val attributes = parentOuterReferences.toSeq + val domains = attributes.map(_.newInstance()) + // A placeholder to be rewritten into domain join. + val domainJoin = DomainJoin(domains, plan) + val outerReferenceMap = attributes.zip(domains).toMap + // Build join conditions between domain attributes and outer references. + // EqualNullSafe is used to make sure null key can be joined together. Note + // outer referenced attributes can be changed during the outer query optimization. + // The equality conditions will also serve as an attribute mapping between new + // outer references and domain attributes when rewriting the domain joins. + // E.g. if the attribute a is changed to a1, the join condition a' <=> outer(a) + // will become a' <=> a1, and we can construct the aliases based on the condition: + // DomainJoin [a'] Join Inner + // +- InnerQuery => :- InnerQuery + // +- Aggregate [a1] [a1 AS a'] + // +- OuterQuery + val conditions = outerReferenceMap.map { + case (o, a) => EqualNullSafe(a, OuterReference(o)) + } + (domainJoin, conditions.toSeq, outerReferenceMap) + } + } else { + // Collect outer references from the current node. + val outerReferences = collectOuterReferences(plan.expressions) + plan match { + case Filter(condition, child) => + val (correlated, uncorrelated) = + splitConjunctivePredicates(condition) + .partition(containsOuter) + val (equality, nonEquality) = correlated.partition(isEquality) + // Find equivalent outer reference relations and remove equivalent attributes from + // parentOuterReferences since they can be replaced directly by expressions + // inside the inner plan. 
+ val equivalences = collectEquivalentOuterReferences(equality) + // When the results are aggregated, outer references inside the non-equality + // predicates cannot be used directly as join conditions with the outer query. + val outerReferences = if (aggregated) { + collectOuterReferences(nonEquality) + } else { + AttributeSet.empty + } + val newOuterReferences = parentOuterReferences ++ outerReferences -- equivalences.keySet + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated) + // Add the mapping from the current node. + val newOuterReferenceMap = outerReferenceMap ++ equivalences + // Replace all outer references in non-equality filter conditions using the domain + // attributes produced for inner query with aggregates. This step is necessary + // for pushing down the non-equality filters into the domain join as join conditions. + val (newFilterCond, newJoinCond) = if (aggregated) { + val nonEqualityCond = replaceOuterReferences(nonEquality, newOuterReferenceMap) + (nonEqualityCond ++ uncorrelated, equality) + } else { + (uncorrelated, correlated) + } + val newFilter = newFilterCond match { + case Nil => newChild + case xs => Filter(xs.reduce(And), newChild) + } + (newFilter, joinCond ++ newJoinCond, newOuterReferenceMap) + + case Project(projectList, child) => + val newOuterReferences = parentOuterReferences ++ outerReferences + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated) + // Replace all outer references in the original project list. + val newProjectList = replaceOuterReferences(projectList, outerReferenceMap) + // Preserve required domain attributes in the join condition by adding the missing + // references to the new project list. 
+ val referencesToAdd = missingReferences(newProjectList.map(_.toAttribute), joinCond) + val newProject = Project(newProjectList ++ referencesToAdd, newChild) + (newProject, joinCond, outerReferenceMap) + + case a @ Aggregate(groupingExpressions, aggregateExpressions, child) => + val newOuterReferences = parentOuterReferences ++ outerReferences + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated = true) + // Replace all outer references in grouping and aggregate expressions. + val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap) + val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap) + // Add all required domain attributes to both grouping and aggregate expressions. + val groupingExprToAdd = missingReferences(newGroupingExpr, joinCond) + val aggExprToAdd = missingReferences(newAggExpr.map(_.toAttribute), joinCond) + val newAggregate = a.copy( + groupingExpressions = newGroupingExpr ++ groupingExprToAdd, + aggregateExpressions = newAggExpr ++ aggExprToAdd, + child = newChild) + (newAggregate, joinCond, outerReferenceMap) + + case j @ Join(left, right, joinType, condition, _) => + // Join condition containing outer references is not supported. 
+ assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j") + val newOuterReferences = parentOuterReferences ++ outerReferences + val shouldPushToLeft = joinType match { + case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true + case _ => hasOuterReferences(left) + } + val shouldPushToRight = joinType match { + case RightOuter | FullOuter => true + case _ => hasOuterReferences(right) + } + val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) { + decorrelate(left, newOuterReferences, aggregated) + } else { + (left, Nil, Map.empty[Attribute, Attribute]) + } + val (newRight, rightJoinCond, rightOuterReferenceMap) = if (shouldPushToRight) { + decorrelate(right, newOuterReferences, aggregated) + } else { + (right, Nil, Map.empty[Attribute, Attribute]) + } + val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap + val newJoinCond = leftJoinCond ++ rightJoinCond + // If we push the dependent join to both sides, we will need to augment the join + // condition such that both sides are matched on the domain attributes. 
+ val augmentedConditions = leftOuterReferenceMap.flatMap { + case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _)) + } + val newCondition = (condition ++ augmentedConditions).reduceOption(And) + val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition) + (newJoin, newJoinCond, newOuterReferenceMap) + + case s: UnaryNode => + assert(outerReferences.isEmpty, s"Correlated column is not allowed in $s") + decorrelate(s.child, parentOuterReferences, aggregated) + + case o => + throw new UnsupportedOperationException( + s"Push down dependent joins through $o is not supported.") + } + } + } + val (newChild, joinCond, _) = decorrelate(BooleanSimplification(innerPlan), AttributeSet.empty) + val (plan, conditions) = deduplicate(newChild, joinCond, outputSet) + (plan, stripOuterReferences(conditions)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 21e87b4c62606..11e5482c5ecdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1393,3 +1393,12 @@ case class CollectMetrics( override protected def withNewChildInternal(newChild: LogicalPlan): CollectMetrics = copy(child = newChild) } + +/** + * A placeholder for domain join that can be added when decorrelating subqueries. + * It should be rewritten during the optimization phase. 
+ */ +case class DomainJoin(domainAttrs: Seq[Attribute], child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output ++ domainAttrs + override def producedAttributes: AttributeSet = AttributeSet(domainAttrs) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f4c236c68dfe9..04e740039f005 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2427,6 +2427,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DECORRELATE_INNER_QUERY_ENABLED = + buildConf("spark.sql.optimizer.decorrelateInnerQuery.enabled") + .internal() + .doc("Decorrelate inner query by eliminating correlated references and build domain joins.") + .version("3.2.0") + .booleanConf + .createWithDefault(true) + val TOP_K_SORT_FALLBACK_THRESHOLD = buildConf("spark.sql.execution.topKSortFallbackThreshold") .internal() @@ -3829,6 +3837,8 @@ class SQLConf extends Serializable with Logging { def legacyIntervalEnabled: Boolean = getConf(LEGACY_INTERVAL_ENABLED) + def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala new file mode 100644 index 0000000000000..f58e473728caf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType + +class DecorrelateInnerQuerySuite extends PlanTest { + + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val c = AttributeReference("c", IntegerType)() + val x = AttributeReference("x", IntegerType)() + val y = AttributeReference("y", IntegerType)() + val z = AttributeReference("z", IntegerType)() + val testRelation = LocalRelation(a, b, c) + val testRelation2 = LocalRelation(x, y, z) + + private def hasOuterReferences(plan: LogicalPlan): Boolean = { + plan.find(_.expressions.exists(SubExprUtils.containsOuter)).isDefined + } + + private def check( + innerPlan: LogicalPlan, + outerPlan: LogicalPlan, + correctAnswer: LogicalPlan, + conditions: Seq[Expression]): Unit = { + val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan) + assert(!hasOuterReferences(outputPlan)) + comparePlans(outputPlan, correctAnswer) + assert(joinCond.length == conditions.length) + 
joinCond.zip(conditions).foreach(e => compareExpressions(e._1, e._2)) + } + + test("filter with correlated equality predicates only") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(a, b), + Filter(OuterReference(x) === a, + testRelation)) + val correctAnswer = Project(Seq(a, b), testRelation) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a)) + } + + test("filter with local and correlated equality predicates") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(a, b), + Filter(And(OuterReference(x) === a, b === 3), + testRelation)) + val correctAnswer = + Project(Seq(a, b), + Filter(b === 3, + testRelation)) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a)) + } + + test("filter with correlated non-equality predicates") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(a, b), + Filter(OuterReference(x) > a, + testRelation)) + val correctAnswer = Project(Seq(a, b), testRelation) + check(innerPlan, outerPlan, correctAnswer, Seq(x > a)) + } + + test("duplicated output attributes") { + val outerPlan = testRelation + val innerPlan = + Project(Seq(a), + Filter(OuterReference(a) === a, + testRelation)) + val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan) + val a1 = outputPlan.output.head + val correctAnswer = + Project(Seq(Alias(a, a1.name)(a1.exprId)), + Project(Seq(a), + testRelation)) + comparePlans(outputPlan, correctAnswer) + assert(joinCond == Seq(a === a1)) + } + + test("filter with equality predicates with correlated values on both sides") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(a), + Filter(OuterReference(x) === OuterReference(y) + b, + testRelation)) + val correctAnswer = Project(Seq(a, b), testRelation) + check(innerPlan, outerPlan, correctAnswer, Seq(x === y + b)) + } + + test("aggregate with correlated equality predicates - 1") { + val outerPlan = testRelation2 + val minB = Alias(min(b), "min_b")() + val innerPlan = + Aggregate(Nil, Seq(minB), + 
Filter(And(OuterReference(x) === a + c, b === 3), + testRelation)) + val correctAnswer = + Aggregate(Seq(a, c), Seq(minB, a, c), + Filter(b === 3, + testRelation)) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a + c)) + } + + test("aggregate with correlated equality predicates - 2") { + val outerPlan = testRelation2 + val minB = Alias(min(b), "min_b")() + val innerPlan = + Aggregate(Nil, Seq(minB), + Filter(OuterReference(x) === OuterReference(y) + a, + testRelation)) + val correctAnswer = + Aggregate(Seq(a), Seq(minB, a), + testRelation) + check(innerPlan, outerPlan, correctAnswer, Seq(x === y + a)) + } + + test("aggregate with correlated equality predicates - 3") { + val outerPlan = testRelation2 + val minB = Alias(min(b), "min_b")() + val innerPlan = + Aggregate(Nil, Seq(minB), + Filter(OuterReference(x) === OuterReference(y), + testRelation)) + val correctAnswer = + Aggregate(Nil, Seq(minB), + testRelation) + check(innerPlan, outerPlan, correctAnswer, Seq(x === y)) + } + + test("aggregate with correlated non-equality predicates") { + val outerPlan = testRelation2 + val minB = Alias(min(b), "min_b")() + val innerPlan = + Aggregate(Nil, Seq(minB), + Filter(OuterReference(x) > a, + testRelation)) + val correctAnswer = + Aggregate(Seq(x), Seq(minB, x), + Filter(x > a, + DomainJoin(Seq(x), testRelation))) + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x)) + } + + test("join with correlated equality predicates") { + val outerPlan = testRelation2 + val joinCondition = Some($"t1.b" === $"t2.b") + val left = + Project(Seq(b), + Filter(OuterReference(x) === b, + testRelation)).as("t1") + val right = + Project(Seq(b), + Filter(OuterReference(x) === a, + testRelation)).as("t2") + Seq(Inner, LeftOuter, LeftSemi, LeftAnti, RightOuter, FullOuter, Cross).foreach { joinType => + val innerPlan = Join(left, right, joinType, joinCondition, JoinHint.NONE).analyze + val newLeft = Project(Seq(b), testRelation).as("t1") + val newRight = Project(Seq(b, a), 
testRelation).as("t2") + // Since the left-hand side has outer(x) = b, and the right-hand side has outer(x) = a, the + // join condition will be augmented with b <=> a. + val newCond = Some(And($"t1.b" <=> $"t2.a", $"t1.b" === $"t2.b")) + val correctAnswer = Join(newLeft, newRight, joinType, newCond, JoinHint.NONE).analyze + check(innerPlan, outerPlan, correctAnswer, Seq(x === b, x === a)) + } + } + + test("correlated values inside join condition") { + val outerPlan = testRelation2 + val innerPlan = + Join( + testRelation.as("t1"), + Filter(OuterReference(y) === 3, testRelation), + Inner, + Some(OuterReference(x) === a), + JoinHint.NONE) + val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan) } + assert(error.getMessage.contains("Correlated column is not allowed in join")) + } + + test("correlated values in project") { + val outerPlan = testRelation2 + val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation()) + val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation())) + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) + } + + test("correlated values in project with alias") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(OuterReference(x), 'y1, 'sum), + Project(Seq( + OuterReference(x), + OuterReference(y).as("y1"), + Add(OuterReference(x), OuterReference(y)).as("sum")), + testRelation)).analyze + val correctAnswer = + Project(Seq(x, 'y1, 'sum, y), + Project(Seq(x, y.as("y1"), (x + y).as("sum"), y), + DomainJoin(Seq(x, y), testRelation))).analyze + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) + } + + test("correlated values in project with correlated equality conditions in filter") { + val outerPlan = testRelation2 + val innerPlan = + Project( + Seq(OuterReference(x)), + Filter( + And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)), + testRelation + ) + ) + val correctAnswer = Project(Seq(a, c), 
Filter(b === 1, testRelation)) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c)) + } + + test("correlated values in project without correlated equality conditions in filter") { + val outerPlan = testRelation2 + val innerPlan = + Project( + Seq(OuterReference(y)), + Filter( + And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)), + testRelation + ) + ) + val correctAnswer = + Project(Seq(y, a, c), + Filter(b === 1, + DomainJoin(Seq(y), testRelation) + ) + ) + check(innerPlan, outerPlan, correctAnswer, Seq(y <=> y, x === a, x + y === c)) + } + + test("correlated values in project with aggregate") { + val outerPlan = testRelation2 + val innerPlan = + Aggregate( + Seq('x1), Seq(min('y1).as("min_y1")), + Project( + Seq(a, OuterReference(x).as("x1"), OuterReference(y).as("y1")), + Filter( + And(OuterReference(x) === a, OuterReference(y) === OuterReference(z)), + testRelation + ) + ) + ).analyze + val correctAnswer = + Aggregate( + Seq('x1, y, a), Seq(min('y1).as("min_y1"), y, a), + Project( + Seq(a, a.as("x1"), y.as("y1"), y), + DomainJoin(Seq(y), testRelation) + ) + ).analyze + check(innerPlan, outerPlan, correctAnswer, Seq(y <=> y, x === a, y === z)) + } +} From 42664542d6005074f8b58a318e8716161ad5836f Mon Sep 17 00:00:00 2001 From: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Date: Tue, 13 Apr 2021 11:45:27 -0700 Subject: [PATCH 2/5] address comments --- .../sql/catalyst/expressions/subquery.scala | 6 +- .../optimizer/DecorrelateInnerQuery.scala | 371 ++++++++++++++++++ .../sql/catalyst/optimizer/subquery.scala | 298 -------------- .../plans/logical/basicLogicalOperators.scala | 2 + 4 files changed, 377 insertions(+), 300 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index dd24362aa80dd..2bedf84271585 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -180,10 +180,12 @@ object SubExprUtils extends PredicateHelper { * }}} * The code below needs to change when we support the above cases. */ - def getOuterReferences(condition: Expression): Seq[Expression] = { + def getOuterReferences(expr: Expression): Seq[Expression] = { val outerExpressions = ArrayBuffer.empty[Expression] - condition transformDown { + expr transformDown { case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => + // Collect and update the sub-tree so that outer references inside this aggregate + // expression will not be collected. For example: min(outer(a)) -> min(a). val newExpr = stripOuterReference(a) outerExpressions += newExpr newExpr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala new file mode 100644 index 0000000000000..9d7383dde6820 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * Decorrelate the inner query by eliminating outer references and create domain joins. + * The implementation is based on the paper: Unnesting Arbitrary Queries by Thomas Neumann + * and Alfons Kemper. https://dl.gi.de/handle/20.500.12116/2418. + * (1) Recursively collects outer references from the inner query until it reaches a node + * that does not contain correlated value. + * (2) Inserts an optional [[DomainJoin]] node to indicate whether a domain (inner) join is + * needed between the outer query and the specific subtree of the inner query. + * (3) Returns a list of join conditions with the outer query and a mapping between outer + * references with references inside the inner query. The parent nodes need to preserve + * the references inside the join conditions and substitute all outer references using + * the mapping. + * + * E.g. decorrelate an inner query with equality predicates: + * + * Aggregate [] [min(c2)] Aggregate [c1] [min(c2), c1] + * +- Filter [outer(c3) = c1] => +- Relation [t] + * +- Relation [t] + * + * Join conditions: [c3 = c1] + * + * E.g. 
decorrelate an inner query with non-equality predicates: + * + * Aggregate [] [min(c2)] Aggregate [c3'] [min(c2), c3'] + * +- Filter [outer(c3) > c1] => +- Filter [c3' > c1] + * +- Relation [t] +- DomainJoin [c3'] + * +- Relation [t] + * + * Join conditions: [c3 <=> c3'] + */ +object DecorrelateInnerQuery extends PredicateHelper { + + /** + * Check if the given expression is an equality condition. + */ + private def isEquality(expression: Expression): Boolean = expression match { + case Equality(_, _) => true + case _ => false + } + + /** + * Collect outer references in an expression that are in the output attributes of the outer plan. + */ + private def collectOuterReferences(expression: Expression): AttributeSet = { + AttributeSet(expression.collect { case o: OuterReference => o.toAttribute }) + } + + /** + * Collect outer references in a sequence of expressions that are in the output attributes + * of the outer plan. + */ + private def collectOuterReferences(expressions: Seq[Expression]): AttributeSet = { + AttributeSet.fromAttributeSets(expressions.map(collectOuterReferences)) + } + + /** + * Build a mapping between outer references with equivalent inner query attributes. + * E.g. [outer(a) = x, y = outer(b), outer(c) = z + 1] => {a -> x, b -> y} + */ + private def collectEquivalentOuterReferences( + expressions: Seq[Expression]): Map[Attribute, Attribute] = { + expressions.collect { + case Equality(o: OuterReference, a: Attribute) => (o.toAttribute, a.toAttribute) + case Equality(a: Attribute, o: OuterReference) => (o.toAttribute, a.toAttribute) + }.toMap + } + + /** + * Replace all outer references using the expressions in the given outer reference map. 
+ */ + private def replaceOuterReference[E <: Expression]( + expression: E, + outerReferenceMap: Map[Attribute, Attribute]): E = { + expression.transform { + case o: OuterReference => outerReferenceMap.getOrElse(o.toAttribute, o) + }.asInstanceOf[E] + } + + /** + * Replace all outer references in the given expressions using the expressions in the + * outer reference map. + */ + private def replaceOuterReferences[E <: Expression]( + expressions: Seq[E], + outerReferenceMap: Map[Attribute, Attribute]): Seq[E] = { + expressions.map(replaceOuterReference(_, outerReferenceMap)) + } + + /** + * Return all references that are present in the join conditions but not in the output + * of the given named expressions. + */ + private def missingReferences( + namedExpressions: Seq[NamedExpression], + joinCond: Seq[Expression]): AttributeSet = { + val output = namedExpressions.map(_.toAttribute) + AttributeSet(joinCond.flatMap(_.references)) -- AttributeSet(output) + } + + /** + * Deduplicate the inner and the outer query attributes and return an aliased + * subquery plan and join conditions if duplicates are found. Duplicated attributes + * can break the structural integrity when joining the inner and outer plan together. 
+ */ + def deduplicate( + innerPlan: LogicalPlan, + conditions: Seq[Expression], + outerOutputSet: AttributeSet): (LogicalPlan, Seq[Expression]) = { + val duplicates = innerPlan.outputSet.intersect(outerOutputSet) + if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = innerPlan.output.map { ref => + aliasMap.getOrElse(ref, ref) + } + val aliasedProjection = Project(aliasedExpressions, innerPlan) + val aliasedConditions = conditions.map(_.transform { + case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute + }) + (aliasedProjection, aliasedConditions) + } else { + (innerPlan, conditions) + } + } + + def apply( + innerPlan: LogicalPlan, + outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + apply(innerPlan, Seq(outerPlan)) + } + + def apply( + innerPlan: LogicalPlan, + outerPlans: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + val outputSet = AttributeSet(outerPlans.flatMap(_.outputSet)) + + // The return type of the recursion. + // The first parameter is a new logical plan with correlation eliminated. + // The second parameter is a list of join conditions with the outer query. + // The third parameter is a mapping between the outer references and equivalent + // expressions from the inner query that is used to replace outer references. + type ReturnType = (LogicalPlan, Seq[Expression], Map[Attribute, Attribute]) + + // Recursively decorrelate the input plan with a set of parent outer references and + // a boolean flag indicating whether the result of the plan will be aggregated. + def decorrelate( + plan: LogicalPlan, + parentOuterReferences: AttributeSet, + aggregated: Boolean = false): ReturnType = { + val isCorrelated = hasOuterReferences(plan) + if (!isCorrelated) { + // We have reached a plan without correlation to the outer plan. 
+ if (parentOuterReferences.isEmpty) { + // If there are no outer references from the parent nodes, it means all outer + // attributes can be substituted by attributes from the inner plan. So no + // domain join is needed. + (plan, Nil, Map.empty[Attribute, Attribute]) + } else { + // Build the domain join with the parent outer references. + val attributes = parentOuterReferences.toSeq + val domains = attributes.map(_.newInstance()) + // A placeholder to be rewritten into domain join. + val domainJoin = DomainJoin(domains, plan) + val outerReferenceMap = attributes.zip(domains).toMap + // Build join conditions between domain attributes and outer references. + // EqualNullSafe is used to make sure null key can be joined together. Note + // outer referenced attributes can be changed during the outer query optimization. + // The equality conditions will also serve as an attribute mapping between new + // outer references and domain attributes when rewriting the domain joins. + // E.g. if the attribute a is changed to a1, the join condition a' <=> outer(a) + // will become a' <=> a1, and we can construct the aliases based on the condition: + // DomainJoin [a'] Join Inner + // +- InnerQuery => :- InnerQuery + // +- Aggregate [a1] [a1 AS a'] + // +- OuterQuery + val conditions = outerReferenceMap.map { + case (o, a) => EqualNullSafe(a, OuterReference(o)) + } + (domainJoin, conditions.toSeq, outerReferenceMap) + } + } else { + plan match { + case Filter(condition, child) => + val conditions = splitConjunctivePredicates(condition) + val (correlated, uncorrelated) = conditions.partition(containsOuter) + // Split the correlated predicates + val (equality, nonEquality) = correlated.partition(isEquality) + // Find outer references that can be substituted by attributes from the inner + // query using the equality predicates. 
+ val equivalences = collectEquivalentOuterReferences(equality) + // Correlated predicates can be removed from the Filter's condition and used as + // join conditions with the outer query. However, if the results of the sub-tree + // is aggregated, only the correlated equality predicates can be used, because + // the inner query attributes from a non-equality predicate need to be preserved + // in both grouping and aggregate expressions, which can change the semantics + // of the plan and lead to incorrect results. Here is an example: + // Relations: + // t1(c1, c2): [(1, 1)] + // t2(c1, c2): [(1, 1), (2, 0)] + // + // Query: + // SELECT * FROM t1 WHERE c1 = (SELECT MAX(c1) FROM t2 WHERE t1.c2 >= c2) + // + // Subquery plan transformation if non-equality predicates are used as join conditions: + // Aggregate [max(c1)] Aggregate [c2] [max(c1), c2] + // +- Filter [outer(c2) >= c2] => +- Relation [c1, c2] + // +- Relation [c1, c2] + // + // Which will be rewritten to this query: + // SELECT c1, c2 FROM t1 LEFT OUTER JOIN + // (SELECT MAX(c1) m, c2 FROM t2 GROUP BY c2) s ON t1.c2 >= s.c2 WHERE c1 = m + // + // The result of the original query should be an empty set but the transformed + // query will output an incorrect result of (1, 1). 
The correct transformation + // is illustrated below: + // Aggregate [max(c1)] Aggregate [c2'] [max(c1), c2'] + // +- Filter [outer(c2) >= c2] => +- Filter [c2' >= c2] + // +- Relation [c1, c2] +- DomainJoin [c2'] + // +- Relation [c1, c2] + // Which will be rewritten to this query (using CTE here to make the query clearer): + // WITH domain AS ( -- [(1, 1)] + // SELECT DISTINCT c2 FROM t1 + // ), domainJoin AS ( -- [(1, 1, 1), (2, 0, 1)] + // SELECT t2.c1, t2.c2, domain.c2 AS dc2 FROM t2 JOIN domain + // ), subquery AS ( -- [(2, 1)] + // SELECT MAX(c1) m, dc2 FROM domainJoin WHERE dc2 >= c2 GROUP BY dc2 + // ) + // SELECT c1, c2 FROM t1 LEFT OUTER JOIN subquery ON c2 <=> dc2 WHERE c1 = m + if (aggregated) { + val outerReferences = collectOuterReferences(nonEquality) + val newOuterReferences = + parentOuterReferences ++ outerReferences -- equivalences.keySet + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated) + // Add the outer references mapping collected from the equality conditions. + val newOuterReferenceMap = outerReferenceMap ++ equivalences + // Replace all outer references in the non-equality predicates. + val nonEqualityCond = replaceOuterReferences(nonEquality, newOuterReferenceMap) + // The new filter condition is the original filter condition with correlated + // equality predicates removed. + val newFilterCond = nonEqualityCond ++ uncorrelated + val newFilter = newFilterCond match { + case Nil => newChild + case conditions => Filter(conditions.reduce(And), newChild) + } + // Equality predicates are used as join conditions with the outer query. + val newJoinCond = joinCond ++ equality + (newFilter, newJoinCond, newOuterReferenceMap) + } else { + // Results of this sub-tree is not aggregated, so all correlated predicates + // can be directly used as outer query join conditions. 
+ val newOuterReferences = parentOuterReferences -- equivalences.keySet + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated) + // Add the outer references mapping collected from the equality conditions. + val newOuterReferenceMap = outerReferenceMap ++ equivalences + val newFilter = uncorrelated match { + case Nil => newChild + case conditions => Filter(conditions.reduce(And), newChild) + } + val newJoinCond = joinCond ++ correlated + (newFilter, newJoinCond, newOuterReferenceMap) + } + + case Project(projectList, child) => + val outerReferences = collectOuterReferences(projectList) + val newOuterReferences = parentOuterReferences ++ outerReferences + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated) + // Replace all outer references in the original project list. + val newProjectList = replaceOuterReferences(projectList, outerReferenceMap) + // Preserve required domain attributes in the join condition by adding the missing + // references to the new project list. + val referencesToAdd = missingReferences(newProjectList, joinCond) + val newProject = Project(newProjectList ++ referencesToAdd, newChild) + (newProject, joinCond, outerReferenceMap) + + case a @ Aggregate(groupingExpressions, aggregateExpressions, child) => + val outerReferences = collectOuterReferences(a.expressions) + val newOuterReferences = parentOuterReferences ++ outerReferences + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated = true) + // Replace all outer references in grouping and aggregate expressions. + val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap) + val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap) + // Add all required domain attributes to both grouping and aggregate expressions. 
+ val referencesToAdd = missingReferences(newAggExpr, joinCond) + val newAggregate = a.copy( + groupingExpressions = newGroupingExpr ++ referencesToAdd, + aggregateExpressions = newAggExpr ++ referencesToAdd, + child = newChild) + (newAggregate, joinCond, outerReferenceMap) + + case j @ Join(left, right, joinType, condition, _) => + val outerReferences = collectOuterReferences(j.expressions) + // Join condition containing outer references is not supported. + assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j") + val newOuterReferences = parentOuterReferences ++ outerReferences + val shouldPushToLeft = joinType match { + case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true + case _ => hasOuterReferences(left) + } + val shouldPushToRight = joinType match { + case RightOuter | FullOuter => true + case _ => hasOuterReferences(right) + } + val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) { + decorrelate(left, newOuterReferences, aggregated) + } else { + (left, Nil, Map.empty[Attribute, Attribute]) + } + val (newRight, rightJoinCond, rightOuterReferenceMap) = if (shouldPushToRight) { + decorrelate(right, newOuterReferences, aggregated) + } else { + (right, Nil, Map.empty[Attribute, Attribute]) + } + val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap + val newJoinCond = leftJoinCond ++ rightJoinCond + // If we push the dependent join to both sides, we can augment the join condition + // such that both sides are matched on the domain attributes. For example, + // - Left Map: {outer(c1) = c1} + // - Right Map: {outer(c1) = 10 - c1} + // Then the join condition can be augmented with (c1 <=> 10 - c1). 
+ val augmentedConditions = leftOuterReferenceMap.flatMap { + case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _)) + } + val newCondition = (condition ++ augmentedConditions).reduceOption(And) + val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition) + (newJoin, newJoinCond, newOuterReferenceMap) + + case u: UnaryNode => + val outerReferences = collectOuterReferences(u.expressions) + assert(outerReferences.isEmpty, s"Correlated column is not allowed in $u") + decorrelate(u.child, parentOuterReferences, aggregated) + + case o => + throw new UnsupportedOperationException( + s"Decorrelate inner query through ${o.nodeName} is not supported.") + } + } + } + val (newChild, joinCond, _) = decorrelate(BooleanSimplification(innerPlan), AttributeSet.empty) + val (plan, conditions) = deduplicate(newChild, joinCond, outputSet) + (plan, stripOuterReferences(conditions)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index aa06004821ac1..48f2cf8e72f3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -724,301 +724,3 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe } } } - -/** - * Decorrelate the inner query by eliminating outer references and create domain joins. - * The implementation is based on the paper: Unnesting Arbitrary Queries by Thomas Neumann - * and Alfons Kemper. https://dl.gi.de/handle/20.500.12116/2418. - * (1) Recursively collects outer references from the inner query until it reaches a node - * that does not contain correlated value. 
- * (2) Inserts an optional [[DomainJoin]] node to indicate whether a domain (inner) join is - * needed between the outer query and the specific subtree of the inner query. - * (3) Returns a list of join conditions with the outer query and a mapping between outer - * references with references inside the inner query. The parent nodes need to preserve - * the references inside the join conditions and substitute all outer references using - * the mapping. - * - * E.g. decorrelate an inner query with equality predicates: - * - * Aggregate [] [min(c2)] Aggregate [c1] [min(c2), c1] - * +- Filter [outer(c3) = c1] => +- Relation [t] - * +- Relation [t] - * - * Join conditions: [c3 = c1] - * - * E.g. decorrelate an inner query with non-equality predicates: - * - * Aggregate [] [min(c2)] Aggregate [c3'] [min(c2), c3'] - * +- Filter [outer(c3) > c1] => +- Filter [c3' > c1] - * +- Relation [t] +- DomainJoin [c3'] - * +- Relation [t] - * - * Join conditions: [c3 <=> c3'] - */ -object DecorrelateInnerQuery extends PredicateHelper { - - /** - * Check if the given expression is an equality condition. - */ - private def isEquality(expression: Expression): Boolean = expression match { - case Equality(_, _) => true - case _ => false - } - - /** - * Collect outer references in an expressions that are in the output attributes of the outer plan. - */ - private def collectOuterReferences(expression: Expression): AttributeSet = { - AttributeSet(expression.collect { case o: OuterReference => o.toAttribute }) - } - - /** - * Collect outer references in a sequence of expressions that are in the output attributes - * of the outer plan. - */ - private def collectOuterReferences(expressions: Seq[Expression]): AttributeSet = { - AttributeSet.fromAttributeSets(expressions.map(collectOuterReferences)) - } - - /** - * Build a mapping between outer references with equivalent inner query attributes. - * E.g. 
[outer(a) = x, y = outer(b), outer(c) = z + 1] => {a -> x, b -> y} - */ - private def collectEquivalentOuterReferences( - expressions: Seq[Expression]): Map[Attribute, Attribute] = { - expressions.collect { - case Equality(o: OuterReference, a: Attribute) => (o.toAttribute, a.toAttribute) - case Equality(a: Attribute, o: OuterReference) => (o.toAttribute, a.toAttribute) - }.toMap - } - - /** - * Replace all outer references using the expressions in the given outer reference map. - */ - private def replaceOuterReference[E <: Expression]( - expression: E, - outerReferenceMap: Map[Attribute, Attribute]): E = { - expression.transform { - case o: OuterReference => outerReferenceMap.getOrElse(o.toAttribute, o) - }.asInstanceOf[E] - } - - /** - * Replace all outer references in the given expressions using the expressions in the - * outer reference map. - */ - private def replaceOuterReferences[E <: Expression]( - expressions: Seq[E], - outerReferenceMap: Map[Attribute, Attribute]): Seq[E] = { - expressions.map(replaceOuterReference(_, outerReferenceMap)) - } - - /** - * Return all missing references of the attribute set from the required attributes - * in the join condition. - */ - private def missingReferences( - expressions: Seq[Expression], - joinCond: Seq[Expression]): AttributeSet = { - val outputSet = AttributeSet(expressions) - AttributeSet(joinCond.flatMap(_.references)) -- outputSet - } - - /** - * Deduplicate the inner and the outer query attributes and return an aliased - * subquery plan and join conditions if duplicates are found. Duplicated attributes - * can break the structural integrity when joining the inner and outer plan together. 
- */ - def deduplicate( - innerPlan: LogicalPlan, - conditions: Seq[Expression], - outerOutputSet: AttributeSet): (LogicalPlan, Seq[Expression]) = { - val duplicates = innerPlan.outputSet.intersect(outerOutputSet) - if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = innerPlan.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val aliasedProjection = Project(aliasedExpressions, innerPlan) - val aliasedConditions = conditions.map(_.transform { - case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute - }) - (aliasedProjection, aliasedConditions) - } else { - (innerPlan, conditions) - } - } - - def apply( - innerPlan: LogicalPlan, - outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = { - apply(innerPlan, Seq(outerPlan)) - } - - def apply( - innerPlan: LogicalPlan, - outerPlans: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { - val outputSet = AttributeSet(outerPlans.flatMap(_.outputSet)) - - // The return type of the recursion. - // The first parameter is a new logical plan with correlation eliminated. - // The second parameter is a list of join conditions with the outer query. - // The third parameter is a mapping between the outer references and equivalent - // expressions from the inner query that is used to replace outer references. - type ReturnType = (LogicalPlan, Seq[Expression], Map[Attribute, Attribute]) - - // Recursively decorrelate the input plan with a set of parent outer references and - // a boolean flag indicating whether the result of the plan will be aggregated. - def decorrelate( - plan: LogicalPlan, - parentOuterReferences: AttributeSet, - aggregated: Boolean = false): ReturnType = { - val isCorrelated = hasOuterReferences(plan) - if (!isCorrelated) { - // We have reached a plan without correlation to the outer plan. 
- if (parentOuterReferences.isEmpty) { - // If there is no outer references from the parent nodes, it means all outer - // attributes can be substituted by attributes from the inner plan. So no - // domain join is needed. - (plan, Nil, Map.empty[Attribute, Attribute]) - } else { - // Build the domain join with the parent outer references. - val attributes = parentOuterReferences.toSeq - val domains = attributes.map(_.newInstance()) - // A placeholder to be rewritten into domain join. - val domainJoin = DomainJoin(domains, plan) - val outerReferenceMap = attributes.zip(domains).toMap - // Build join conditions between domain attributes and outer references. - // EqualNullSafe is used to make sure null key can be joined together. Note - // outer referenced attributes can be changed during the outer query optimization. - // The equality conditions will also serve as an attribute mapping between new - // outer references and domain attributes when rewriting the domain joins. - // E.g. if the attribute a is changed to a1, the join condition a' <=> outer(a) - // will become a' <=> a1, and we can construct the aliases based on the condition: - // DomainJoin [a'] Join Inner - // +- InnerQuery => :- InnerQuery - // +- Aggregate [a1] [a1 AS a'] - // +- OuterQuery - val conditions = outerReferenceMap.map { - case (o, a) => EqualNullSafe(a, OuterReference(o)) - } - (domainJoin, conditions.toSeq, outerReferenceMap) - } - } else { - // Collect outer references from the current node. - val outerReferences = collectOuterReferences(plan.expressions) - plan match { - case Filter(condition, child) => - val (correlated, uncorrelated) = - splitConjunctivePredicates(condition) - .partition(containsOuter) - val (equality, nonEquality) = correlated.partition(isEquality) - // Find equivalent outer reference relations and remove equivalent attributes from - // parentOuterReferences since they can be replaced directly by expressions - // inside the inner plan. 
- val equivalences = collectEquivalentOuterReferences(equality) - // When the results are aggregated, outer references inside the non-equality - // predicates cannot be used directly as join conditions with the outer query. - val outerReferences = if (aggregated) { - collectOuterReferences(nonEquality) - } else { - AttributeSet.empty - } - val newOuterReferences = parentOuterReferences ++ outerReferences -- equivalences.keySet - val (newChild, joinCond, outerReferenceMap) = - decorrelate(child, newOuterReferences, aggregated) - // Add the mapping from the current node. - val newOuterReferenceMap = outerReferenceMap ++ equivalences - // Replace all outer references in non-equality filter conditions using the domain - // attributes produced for inner query with aggregates. This step is necessary - // for pushing down the non-equality filters into the domain join as join conditions. - val (newFilterCond, newJoinCond) = if (aggregated) { - val nonEqualityCond = replaceOuterReferences(nonEquality, newOuterReferenceMap) - (nonEqualityCond ++ uncorrelated, equality) - } else { - (uncorrelated, correlated) - } - val newFilter = newFilterCond match { - case Nil => newChild - case xs => Filter(xs.reduce(And), newChild) - } - (newFilter, joinCond ++ newJoinCond, newOuterReferenceMap) - - case Project(projectList, child) => - val newOuterReferences = parentOuterReferences ++ outerReferences - val (newChild, joinCond, outerReferenceMap) = - decorrelate(child, newOuterReferences, aggregated) - // Replace all outer references in the original project list. - val newProjectList = replaceOuterReferences(projectList, outerReferenceMap) - // Preserve required domain attributes in the join condition by adding the missing - // references to the new project list. 
- val referencesToAdd = missingReferences(newProjectList.map(_.toAttribute), joinCond) - val newProject = Project(newProjectList ++ referencesToAdd, newChild) - (newProject, joinCond, outerReferenceMap) - - case a @ Aggregate(groupingExpressions, aggregateExpressions, child) => - val newOuterReferences = parentOuterReferences ++ outerReferences - val (newChild, joinCond, outerReferenceMap) = - decorrelate(child, newOuterReferences, aggregated = true) - // Replace all outer references in grouping and aggregate expressions. - val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap) - val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap) - // Add all required domain attributes to both grouping and aggregate expressions. - val groupingExprToAdd = missingReferences(newGroupingExpr, joinCond) - val aggExprToAdd = missingReferences(newAggExpr.map(_.toAttribute), joinCond) - val newAggregate = a.copy( - groupingExpressions = newGroupingExpr ++ groupingExprToAdd, - aggregateExpressions = newAggExpr ++ aggExprToAdd, - child = newChild) - (newAggregate, joinCond, outerReferenceMap) - - case j @ Join(left, right, joinType, condition, _) => - // Join condition containing outer references is not supported. 
- assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j") - val newOuterReferences = parentOuterReferences ++ outerReferences - val shouldPushToLeft = joinType match { - case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true - case _ => hasOuterReferences(left) - } - val shouldPushToRight = joinType match { - case RightOuter | FullOuter => true - case _ => hasOuterReferences(right) - } - val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) { - decorrelate(left, newOuterReferences, aggregated) - } else { - (left, Nil, Map.empty[Attribute, Attribute]) - } - val (newRight, rightJoinCond, rightOuterReferenceMap) = if (shouldPushToRight) { - decorrelate(right, newOuterReferences, aggregated) - } else { - (right, Nil, Map.empty[Attribute, Attribute]) - } - val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap - val newJoinCond = leftJoinCond ++ rightJoinCond - // If we push the dependent join to both sides, we will need to augment the join - // condition such that both sides are matched on the domain attributes. 
- val augmentedConditions = leftOuterReferenceMap.flatMap { - case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _)) - } - val newCondition = (condition ++ augmentedConditions).reduceOption(And) - val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition) - (newJoin, newJoinCond, newOuterReferenceMap) - - case s: UnaryNode => - assert(outerReferences.isEmpty, s"Correlated column is not allowed in $s") - decorrelate(s.child, parentOuterReferences, aggregated) - - case o => - throw new UnsupportedOperationException( - s"Push down dependent joins through $o is not supported.") - } - } - } - val (newChild, joinCond, _) = decorrelate(BooleanSimplification(innerPlan), AttributeSet.empty) - val (plan, conditions) = deduplicate(newChild, joinCond, outputSet) - (plan, stripOuterReferences(conditions)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 11e5482c5ecdd..49e3e3c9ee999 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1401,4 +1401,6 @@ case class CollectMetrics( case class DomainJoin(domainAttrs: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output ++ domainAttrs override def producedAttributes: AttributeSet = AttributeSet(domainAttrs) + override protected def withNewChildInternal(newChild: LogicalPlan): DomainJoin = + copy(child = newChild) } From d804c2292b3a352705da937e3c57a55b8a968c62 Mon Sep 17 00:00:00 2001 From: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Date: Thu, 15 Apr 2021 09:07:15 -0700 Subject: [PATCH 3/5] address comments --- .../optimizer/DecorrelateInnerQuery.scala | 169 
++++++++++++------ 1 file changed, 112 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala index 9d7383dde6820..546b21bdcc3b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -26,39 +26,81 @@ import org.apache.spark.sql.catalyst.plans.logical._ * Decorrelate the inner query by eliminating outer references and create domain joins. * The implementation is based on the paper: Unnesting Arbitrary Queries by Thomas Neumann * and Alfons Kemper. https://dl.gi.de/handle/20.500.12116/2418. - * (1) Recursively collects outer references from the inner query until it reaches a node - * that does not contain correlated value. - * (2) Inserts an optional [[DomainJoin]] node to indicate whether a domain (inner) join is - * needed between the outer query and the specific subtree of the inner query. - * (3) Returns a list of join conditions with the outer query and a mapping between outer - * references with references inside the inner query. The parent nodes need to preserve - * the references inside the join conditions and substitute all outer references using - * the mapping. + * + * A correlated subquery can be viewed as a "dependent" nested loop join between the outer and + * the inner query. For each row produced by the outer query, we bind the [[OuterReference]]s + * in the inner query with the corresponding values in the row, and then evaluate the inner query. + * + * Dependent Join + * :- Outer Query + * +- Inner Query + * + * If the [[OuterReference]]s are bound to the same value, the inner query will return the same + * result. 
Based on this, we can reduce the times to evaluate the inner query by first getting + * all distinct values of the [[OuterReference]]s. + * + * Normal Join + * :- Outer Query + * +- Dependent Join + * :- Inner Query + * +- Distinct Aggregate (outer_ref1, outer_ref2, ...) + * +- Outer Query + * + * The distinct aggregate of the outer references is called a "domain", and the dependent join + * between the inner query and the domain is called a "domain join". We need to push down the + * domain join through the inner query until there is no outer reference in the sub-tree and + * the domain join will turn into a normal join. + * + * The decorrelation function returns a new query plan with optional placeholder [[DomainJoin]]s + * added and a list of join conditions with the outer query. [[DomainJoin]]s need to be rewritten + * into actual inner joins between the inner query sub-tree and the outer query. * * E.g. decorrelate an inner query with equality predicates: * - * Aggregate [] [min(c2)] Aggregate [c1] [min(c2), c1] - * +- Filter [outer(c3) = c1] => +- Relation [t] - * +- Relation [t] + * SELECT (SELECT MIN(b) FROM t1 WHERE t2.c = t1.a) FROM t2 * - * Join conditions: [c3 = c1] + * Aggregate [] [min(b)] Aggregate [a] [min(b), a] + * +- Filter (outer(c) = a) => +- Relation [t1] + * +- Relation [t1] + * + * Join conditions: [c = a] * * E.g. 
decorrelate an inner query with non-equality predicates: * - * Aggregate [] [min(c2)] Aggregate [c3'] [min(c2), c3'] - * +- Filter [outer(c3) > c1] => +- Filter [c3' > c1] - * +- Relation [t] +- DomainJoin [c3'] - * +- Relation [t] + * SELECT (SELECT MIN(b) FROM t1 WHERE t2.c > t1.a) FROM t2 + * + * Aggregate [] [min(b)] Aggregate [c'] [min(b), c'] + * +- Filter (outer(c) > a) => +- Filter (c' > a) + * +- Relation [t1] +- DomainJoin [c'] + * +- Relation [t1] * - * Join conditions: [c3 <=> c3'] + * Join conditions: [c <=> c'] */ object DecorrelateInnerQuery extends PredicateHelper { /** - * Check if the given expression is an equality condition. + * Check if an expression contains any attribute. Note OuterReference is a + * leaf node and will not be found here. + */ + private def containsAttribute(expression: Expression): Boolean = { + expression.find(_.isInstanceOf[Attribute]).isDefined + } + + /** + * Check if an expression can be pulled up over an [[Aggregate]] without changing the + * semantics of the plan. The expression must be an equality predicate that guarantees + * one-to-one mapping between inner and outer attributes. More specifically, one side + * of the predicate must be an attribute and another side of the predicate must not + * contain other attributes from the inner query. + * For example: + * (a = outer(c)) -> true + * (a > outer(c)) -> false + * (a + b = outer(c)) -> false + * (a = outer(c) - b) -> false */ - private def isEquality(expression: Expression): Boolean = expression match { - case Equality(_, _) => true + private def canPullUpOverAgg(expression: Expression): Boolean = expression match { + case Equality(_: Attribute, b) => !containsAttribute(b) + case Equality(a, _: Attribute) => !containsAttribute(a) case _ => false } @@ -166,8 +208,16 @@ object DecorrelateInnerQuery extends PredicateHelper { // expressions from the inner query that is used to replace outer references. 
type ReturnType = (LogicalPlan, Seq[Expression], Map[Attribute, Attribute]) - // Recursively decorrelate the input plan with a set of parent outer references and - // a boolean flag indicating whether the result of the plan will be aggregated. + // Decorrelate the input plan with a set of parent outer references and a boolean flag + // indicating whether the result of the plan will be aggregated. Steps: + // 1. Recursively collects outer references from the inner query until it reaches a node + // that does not contain correlated value. + // 2. Inserts an optional [[DomainJoin]] node to indicate whether a domain (inner) join is + // needed between the outer query and the specific sub-tree of the inner query. + // 3. Returns a list of join conditions with the outer query and a mapping between outer + // references with references inside the inner query. The parent nodes need to preserve + // the references inside the join conditions and substitute all outer references using + // the mapping. def decorrelate( plan: LogicalPlan, parentOuterReferences: AttributeSet, @@ -208,51 +258,56 @@ object DecorrelateInnerQuery extends PredicateHelper { case Filter(condition, child) => val conditions = splitConjunctivePredicates(condition) val (correlated, uncorrelated) = conditions.partition(containsOuter) - // Split the correlated predicates - val (equality, nonEquality) = correlated.partition(isEquality) // Find outer references that can be substituted by attributes from the inner // query using the equality predicates. - val equivalences = collectEquivalentOuterReferences(equality) + val equivalences = collectEquivalentOuterReferences(correlated) // Correlated predicates can be removed from the Filter's condition and used as // join conditions with the outer query. 
However, if the results of the sub-tree - // is aggregated, only the correlated equality predicates can be used, because - // the inner query attributes from a non-equality predicate need to be preserved - // in both grouping and aggregate expressions, which can change the semantics - // of the plan and lead to incorrect results. Here is an example: + // is aggregated, only certain correlated equality predicates can be used, because + // the references in the join conditions need to be preserved in both the grouping + // and aggregate expressions of an Aggregate, which may change the semantics of the + // plan and lead to incorrect results. Here is an example: // Relations: - // t1(c1, c2): [(1, 1)] - // t2(c1, c2): [(1, 1), (2, 0)] + // t1(a, b): [(1, 1)] + // t2(c, d): [(1, 1), (2, 0)] // // Query: - // SELECT * FROM t1 WHERE c1 = (SELECT MAX(c1) FROM t2 WHERE t1.c2 >= c2) + // SELECT * FROM t1 WHERE a = (SELECT MAX(c) FROM t2 WHERE b >= d) // - // Subquery plan transformation if non-equality predicates are used as join conditions: - // Aggregate [max(c1)] Aggregate [c2] [max(c1), c2] - // +- Filter [outer(c2) >= c2] => +- Relation [c1, c2] - // +- Relation [c1, c2] + // Subquery plan transformation if correlated predicates are used as join conditions: + // Aggregate [max(c)] Aggregate [d] [max(c), d] + // +- Filter (outer(b) >= d) => +- Relation [c, d] + // +- Relation [c, d] // - // Which will be rewritten to this query: - // SELECT c1, c2 FROM t1 LEFT OUTER JOIN - // (SELECT MAX(c1) m, c2 FROM t2 GROUP BY c2) s ON t1.c2 >= s.c2 WHERE c1 = m + // Plan after rewrite: + // Project [a, b] -- [(1, 1)] + // +- Join LeftOuter (b >= d AND a = max(c)) + // :- Relation [a, b] + // +- Aggregate [d] [max(c), d] -- [(1, 1), (2, 0)] + // +- Relation [c, d] // // The result of the original query should be an empty set but the transformed // query will output an incorrect result of (1, 1). 
The correct transformation - // is illustrated below: - // Aggregate [max(c1)] Aggregate [c2'] [max(c1), c2'] - // +- Filter [outer(c2) >= c2] => +- Filter [c2' >= c2] - // +- Relation [c1, c2] +- DomainJoin [c2'] - // +- Relation [c1, c2] - // Which will be rewritten to this query (using CTE here to make the query clearer): - // WITH domain AS ( -- [(1, 1)] - // SELECT DISTINCT c2 FROM t1 - // ), domainJoin AS ( -- [(1, 1, 1), (2, 0, 1)] - // SELECT t2.c1, t2.c2, domain.c2 AS dc2 FROM t2 JOIN domain - // ), subquery AS ( -- [(2, 1)] - // SELECT MAX(c1) m, dc2 FROM domainJoin WHERE dc2 >= c2 GROUP BY dc2 - // ) - // SELECT c1, c2 FROM t1 LEFT OUTER JOIN subquery ON c2 <=> dc2 WHERE c1 = m + // with domain join is illustrated below: + // Aggregate [max(c)] Aggregate [b'] [max(c), b'] + // +- Filter (outer(b) >= d) => +- Filter (b' >= d) + // +- Relation [c, d] +- DomainJoin [b'] + // +- Relation [c, d] + // Plan after rewrite: + // Project [a, b] + // +- Join LeftOuter (b <=> b' AND a = max(c)) -- [] + // :- Relation [a, b] + // +- Aggregate [b'] [max(c), b'] -- [(2, 1)] + // +- Join Inner (b' >= d) -- [(1, 1, 1), (2, 0, 1)] (DomainJoin) + // :- Relation [c, d] + // +- Aggregate [b] [b AS b'] -- [(1)] (Domain) + // +- Relation [a, b] if (aggregated) { - val outerReferences = collectOuterReferences(nonEquality) + // Split the correlated predicates into predicates that can and cannot be directly + // used as join conditions with the outer query depending on whether they can + // be pulled up over an Aggregate without changing the semantics of the plan. + val (equalityCond, predicates) = correlated.partition(canPullUpOverAgg) + val outerReferences = collectOuterReferences(predicates) val newOuterReferences = parentOuterReferences ++ outerReferences -- equivalences.keySet val (newChild, joinCond, outerReferenceMap) = @@ -260,16 +315,16 @@ object DecorrelateInnerQuery extends PredicateHelper { // Add the outer references mapping collected from the equality conditions. 
val newOuterReferenceMap = outerReferenceMap ++ equivalences // Replace all outer references in the non-equality predicates. - val nonEqualityCond = replaceOuterReferences(nonEquality, newOuterReferenceMap) + val newCorrelated = replaceOuterReferences(predicates, newOuterReferenceMap) // The new filter condition is the original filter condition with correlated // equality predicates removed. - val newFilterCond = nonEqualityCond ++ uncorrelated + val newFilterCond = newCorrelated ++ uncorrelated val newFilter = newFilterCond match { case Nil => newChild case conditions => Filter(conditions.reduce(And), newChild) } // Equality predicates are used as join conditions with the outer query. - val newJoinCond = joinCond ++ equality + val newJoinCond = joinCond ++ equalityCond (newFilter, newJoinCond, newOuterReferenceMap) } else { // Results of this sub-tree is not aggregated, so all correlated predicates From c34e7009640bf89ff5a05380b0d35570b76b8cc2 Mon Sep 17 00:00:00 2001 From: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Date: Thu, 15 Apr 2021 10:29:42 -0700 Subject: [PATCH 4/5] update tests and small refactor --- .../optimizer/DecorrelateInnerQuery.scala | 60 ++++++++++++++++++- .../sql/catalyst/optimizer/subquery.scala | 56 +---------------- .../DecorrelateInnerQuerySuite.scala | 19 +++--- 3 files changed, 72 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala index 546b21bdcc3b4..377dcd6666d64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -101,7 +101,7 @@ object DecorrelateInnerQuery extends PredicateHelper { private def canPullUpOverAgg(expression: Expression): Boolean = expression match { case 
Equality(_: Attribute, b) => !containsAttribute(b) case Equality(a, _: Attribute) => !containsAttribute(a) - case _ => false + case o => !containsAttribute(o) } /** @@ -190,6 +190,64 @@ object DecorrelateInnerQuery extends PredicateHelper { } } + /** + * Build a mapping between domain attributes and corresponding outer query expressions + * using the join conditions. + */ + private def buildDomainAttrMap( + conditions: Seq[Expression], + domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = { + val domainAttrSet = AttributeSet(domainAttrs) + conditions.collect { + // When we build the join conditions between the domain attributes and outer references, + // the left hand side is always the domain attribute used in the inner query and the right + // hand side is the attribute from the outer query. Note here the right hand side of a + // condition is not necessarily an attribute, for example it can be a literal (if foldable) + // or a cast expression after the optimization. + case EqualNullSafe(left: Attribute, right: Expression) if domainAttrSet.contains(left) => + left -> right + }.toMap + } + + /** + * Rewrite all [[DomainJoin]]s in the inner query to actual inner joins with the outer query. + */ + def rewriteDomainJoins( + outerPlan: LogicalPlan, + innerPlan: LogicalPlan, + conditions: Seq[Expression]): LogicalPlan = { + innerPlan transform { + case d @ DomainJoin(domainAttrs, child) => + val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs) + // We should only rewrite a domain join when all corresponding outer plan attributes + // can be found from the join condition. + if (domainAttrMap.size == domainAttrs.size) { + val groupingExprs = domainAttrs.map(domainAttrMap) + val aggregateExprs = groupingExprs.zip(domainAttrs).map { + // Rebuild the aliases. + case (inputAttr, outputAttr) => Alias(inputAttr, outputAttr.name)(outputAttr.exprId) + } + // Construct a domain with the outer query plan. 
+ // DomainJoin [a', b'] => Aggregate [a, b] [a AS a', b AS b'] + // +- Relation [a, b] + val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan) + child match { + // A special optimization for OneRowRelation. + // TODO: add a more general rule to optimize join with OneRowRelation. + case _: OneRowRelation => domain + // Construct a domain join. + // Join Inner + // :- Inner Query + // +- Domain + case _ => Join(child, domain, Inner, None, JoinHint.NONE) + } + } else { + throw new UnsupportedOperationException( + s"Unable to rewrite domain join with conditions: $conditions\n$d") + } + } + } + def apply( innerPlan: LogicalPlan, outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 48f2cf8e72f3d..9381796d3d06b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -295,7 +295,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper if (newCond.isEmpty) oldCond else newCond } - def rewrite(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + def decorrelate(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { if (SQLConf.get.decorrelateInnerQueryEnabled) { DecorrelateInnerQuery(sub, outer) } else { @@ -305,7 +305,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper plan transformExpressions { case ScalarSubquery(sub, children, exprId) if children.nonEmpty => - val (newPlan, newCond) = rewrite(sub, outerPlans) + val (newPlan, newCond) = decorrelate(sub, outerPlans) ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId) case Exists(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = 
pullOutCorrelatedPredicates(sub, outerPlans) @@ -509,56 +509,6 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe // Name of generated column used in rewrite below val ALWAYS_TRUE_COLNAME = "alwaysTrue" - /** - * Build a mapping between domain attributes and corresponding outer query expressions - * using the join conditions. - */ - private def buildDomainAttrMap( - conditions: Seq[Expression], - domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = { - val outputSet = AttributeSet(domainAttrs) - conditions.collect { - // When we build the equality conditions, the left side is always the - // domain attributes used in the inner plan, and the right side is the - // attribute from outer plan. Note the right hand side is not necessarily - // an attribute, for example it can be a literal (if foldable) or a cast expression. - case EqualNullSafe(left: Attribute, right: Expression) if outputSet.contains(left) => - left -> right - }.toMap - } - - /** - * Rewrite domain join placeholder to actual inner joins. - */ - private def rewriteDomainJoins( - outerPlan: LogicalPlan, - innerPlan: LogicalPlan, - conditions: Seq[Expression]): LogicalPlan = { - innerPlan transform { - case d @ DomainJoin(domainAttrs, child) => - val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs) - // We should only rewrite a domain join when all corresponding outer plan attributes - // can be found from the join condition. - if (domainAttrMap.size == domainAttrs.size) { - val groupingExprs = domainAttrs.map(domainAttrMap) - val aggregateExprs = groupingExprs.zip(domainAttrs).map { - // Rebuild the aliases. - case (inputAttr, outputAttr) => Alias(inputAttr, outputAttr.name)(outputAttr.exprId) - } - val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan) - child match { - // A special optimization for OneRowRelation. - // TODO: add a more general rule to optimize join with OneRowRelation. 
- case _: OneRowRelation => domain - case _ => Join(child, domain, Inner, None, JoinHint.NONE) - } - } else { - throw new UnsupportedOperationException( - s"Unable to rewrite domain join with conditions: $conditions\n$d") - } - } - } - /** * Construct a new child plan by left joining the given subqueries to a base plan. * This method returns the child plan and an attribute mapping @@ -571,7 +521,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]() val newChild = subqueries.foldLeft(child) { case (currentChild, ScalarSubquery(sub, conditions, _)) => - val query = rewriteDomainJoins(currentChild, sub, conditions) + val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions) val origOutput = query.output.head val resultWithZeroTups = evalSubqueryOnZeroTups(query) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala index f58e473728caf..93b27035aca33 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala @@ -110,21 +110,21 @@ class DecorrelateInnerQuerySuite extends PlanTest { check(innerPlan, outerPlan, correctAnswer, Seq(x === y + b)) } - test("aggregate with correlated equality predicates - 1") { + test("aggregate with correlated equality predicates that can be pulled up") { val outerPlan = testRelation2 val minB = Alias(min(b), "min_b")() val innerPlan = Aggregate(Nil, Seq(minB), - Filter(And(OuterReference(x) === a + c, b === 3), + Filter(And(OuterReference(x) === a, b === 3), testRelation)) val correctAnswer = - Aggregate(Seq(a, c), Seq(minB, a, c), + Aggregate(Seq(a), Seq(minB, a), Filter(b === 3, testRelation)) - 
check(innerPlan, outerPlan, correctAnswer, Seq(x === a + c)) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a)) } - test("aggregate with correlated equality predicates - 2") { + test("aggregate with correlated equality predicates that cannot be pulled up") { val outerPlan = testRelation2 val minB = Alias(min(b), "min_b")() val innerPlan = @@ -132,12 +132,13 @@ class DecorrelateInnerQuerySuite extends PlanTest { Filter(OuterReference(x) === OuterReference(y) + a, testRelation)) val correctAnswer = - Aggregate(Seq(a), Seq(minB, a), - testRelation) - check(innerPlan, outerPlan, correctAnswer, Seq(x === y + a)) + Aggregate(Seq(x, y), Seq(minB, x, y), + Filter(x === y + a, + DomainJoin(Seq(x, y), testRelation))) + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) } - test("aggregate with correlated equality predicates - 3") { + test("aggregate with correlated equality predicates that has no attribute") { val outerPlan = testRelation2 val minB = Alias(min(b), "min_b")() val innerPlan = From f665a0d52953ea7927d0d5dee7f4742255173691 Mon Sep 17 00:00:00 2001 From: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Date: Mon, 19 Apr 2021 21:00:57 -0700 Subject: [PATCH 5/5] retrigger tests