From d09d905180c0eb9d8b1ed9aaf5d207e2fac6547c Mon Sep 17 00:00:00 2001
From: Min Qiu
Date: Tue, 1 Dec 2015 18:01:02 -0800
Subject: [PATCH 1/4] Extract the common equality conditions that can be used
 as a join condition

---
 .../sql/catalyst/optimizer/Optimizer.scala | 50 +++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f4dba67f13b54..f390751f8e8c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -40,6 +40,8 @@ object DefaultOptimizer extends Optimizer {
     Batch("Aggregate", FixedPoint(100),
       ReplaceDistinctWithAggregate,
       RemoveLiteralFromGroupExpressions) ::
+    Batch("CNF factorization", FixedPoint(100),
+      ExtractEqualJoinCondition) ::
     Batch("Operator Optimizations", FixedPoint(100),
       // Operator push down
       SetOperationPushDown,
@@ -911,3 +913,51 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
     a.copy(groupingExpressions = newGrouping)
   }
 }
+
+/**
+ * Extracts the equal-join condition if any, so that query planner avoids generating cartsian
+ * product which cause out of memory exception, and performance issues
+ */
+object ExtractEqualJoinCondition extends Rule[LogicalPlan] with PredicateHelper{
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case f @ Join(left, right, joinType, joinCondition) =>
+      joinCondition match {
+        case Some(e) if isDNF(e) => {
+          val disjConditions = splitDisjunctivePredicates(e)
+          val exprMatrix = disjConditions.map(splitConjunctivePredicates)
+          if(exprMatrix.length <= 1) f
+          else {
+            val pattern = exprMatrix(0)
+            val comExprs: Seq[Expression] = pattern.filter(p => isCommonExpr(p, exprMatrix, 1))
+            val newExprMatrix = exprMatrix.map(_.diff(comExprs))
+            val newJoinCond = (comExprs :+ newExprMatrix.map(_.reduceLeft(And)).reduceLeft(Or))
+              .reduceLeftOption(And)
+            Join(left, right, joinType, newJoinCond)
+          }
+        }
+        case _ => f
+      }
+  }
+
+  def isCommonExpr(pattern: Expression, matrix: Seq[Seq[Expression]], startIndex: Int) : Boolean = {
+    val duplicatedCount = matrix.drop(startIndex).count(arr => arr.contains(pattern))
+    return duplicatedCount == matrix.length - startIndex
+  }
+
+  def isDNF(condition: Expression) : Boolean = {
+    condition match {
+      case Or(left, right) => isDNF(left) && isDNF(right)
+      case And(left, right) => isCNF(left) && isCNF(right)
+      case _ => true
+    }
+  }
+
+  def isCNF(condition: Expression): Boolean = {
+    condition match {
+      case And(left, right) => isCNF(left) && isCNF(right)
+      case Or(left, right) => false
+      case _ => true
+    }
+  }
+}
+

From 08a76aefcc6036d600fa92aee037515cce22ae09 Mon Sep 17 00:00:00 2001
From: Min Qiu
Date: Wed, 2 Dec 2015 17:01:41 -0800
Subject: [PATCH 2/4] bug fix for ColumnPruning rule to deal with Project <-
 Filter <- Join case

---
 .../spark/sql/catalyst/optimizer/Optimizer.scala | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f390751f8e8c2..406e8426374b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -199,6 +199,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  * - Aggregate
  * - Generate
  * - Project <- Join
+ * - Project <- Filter <- Join
  * - LeftSemiJoin
  */
 object ColumnPruning extends Rule[LogicalPlan] {
@@ -248,6 +249,16 @@ object ColumnPruning extends Rule[LogicalPlan] {
       Project(projectList,
         Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
 
+    // Eliminate unneeded attributes from either side of a Join.
+    case Project(projectList, Filter(predicates, Join(left, right, joinType, condition))) =>
+      val allReferences: AttributeSet =
+        AttributeSet(
+          projectList.flatMap(_.references.iterator)) ++
+        predicates.references ++
+        condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
+      Project(projectList, Filter(predicates, Join(
+        prunedChild(left, allReferences), prunedChild(right, allReferences), joinType, condition)))
+
     // Eliminate unneeded attributes from right side of a LeftSemiJoin.
     case Join(left, right, LeftSemi, condition) =>
       // Collect the list of all references required to evaluate the condition.

From b002b393124568f6d171e5619d40ff749b9b639d Mon Sep 17 00:00:00 2001
From: Min Qiu
Date: Wed, 2 Dec 2015 18:04:54 -0800
Subject: [PATCH 3/4] bug fix for ColumnPruning rule to deal with Project <-
 Filter <- Join case

---
 .../sql/catalyst/optimizer/Optimizer.scala | 50 ------------------
 1 file changed, 50 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 406e8426374b0..098a5a8ee7154 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -40,8 +40,6 @@ object DefaultOptimizer extends Optimizer {
     Batch("Aggregate", FixedPoint(100),
       ReplaceDistinctWithAggregate,
       RemoveLiteralFromGroupExpressions) ::
-    Batch("CNF factorization", FixedPoint(100),
-      ExtractEqualJoinCondition) ::
     Batch("Operator Optimizations", FixedPoint(100),
       // Operator push down
       SetOperationPushDown,
@@ -924,51 +922,3 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
     a.copy(groupingExpressions = newGrouping)
   }
 }
-
-/**
- * Extracts the equal-join condition if any, so that query planner avoids generating cartsian
- * product which cause out of memory exception, and performance issues
- */
-object ExtractEqualJoinCondition extends Rule[LogicalPlan] with PredicateHelper{
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case f @ Join(left, right, joinType, joinCondition) =>
-      joinCondition match {
-        case Some(e) if isDNF(e) => {
-          val disjConditions = splitDisjunctivePredicates(e)
-          val exprMatrix = disjConditions.map(splitConjunctivePredicates)
-          if(exprMatrix.length <= 1) f
-          else {
-            val pattern = exprMatrix(0)
-            val comExprs: Seq[Expression] = pattern.filter(p => isCommonExpr(p, exprMatrix, 1))
-            val newExprMatrix = exprMatrix.map(_.diff(comExprs))
-            val newJoinCond = (comExprs :+ newExprMatrix.map(_.reduceLeft(And)).reduceLeft(Or))
-              .reduceLeftOption(And)
-            Join(left, right, joinType, newJoinCond)
-          }
-        }
-        case _ => f
-      }
-  }
-
-  def isCommonExpr(pattern: Expression, matrix: Seq[Seq[Expression]], startIndex: Int) : Boolean = {
-    val duplicatedCount = matrix.drop(startIndex).count(arr => arr.contains(pattern))
-    return duplicatedCount == matrix.length - startIndex
-  }
-
-  def isDNF(condition: Expression) : Boolean = {
-    condition match {
-      case Or(left, right) => isDNF(left) && isDNF(right)
-      case And(left, right) => isCNF(left) && isCNF(right)
-      case _ => true
-    }
-  }
-
-  def isCNF(condition: Expression): Boolean = {
-    condition match {
-      case And(left, right) => isCNF(left) && isCNF(right)
-      case Or(left, right) => false
-      case _ => true
-    }
-  }
-}
-

From e22892dab78bfe1dcf3c9a482b3c9b7e64c0a119 Mon Sep 17 00:00:00 2001
From: Min Qiu
Date: Thu, 10 Dec 2015 15:08:31 -0800
Subject: [PATCH 4/4] a new rule for adjusting the join order

---
 .../sql/catalyst/optimizer/Optimizer.scala | 110 ++++++++++++++++--
 1 file changed, 98 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 098a5a8ee7154..585d5605e49b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -66,7 +66,9 @@ object DefaultOptimizer extends Optimizer {
     Batch("Decimal Optimizations", FixedPoint(100),
       DecimalAggregates) ::
     Batch("LocalRelation", FixedPoint(100),
-      ConvertToLocalRelation) :: Nil
+      ConvertToLocalRelation) ::
+    Batch("Join Order Adjustment", FixedPoint(100),
+      AdjustJoinOrderWithEqualConditions) :: Nil
 }
 
 /**
@@ -197,7 +199,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  * - Aggregate
  * - Generate
  * - Project <- Join
- * - Project <- Filter <- Join
  * - LeftSemiJoin
  */
 object ColumnPruning extends Rule[LogicalPlan] {
@@ -247,16 +248,6 @@ object ColumnPruning extends Rule[LogicalPlan] {
       Project(projectList,
         Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
 
-    // Eliminate unneeded attributes from either side of a Join.
-    case Project(projectList, Filter(predicates, Join(left, right, joinType, condition))) =>
-      val allReferences: AttributeSet =
-        AttributeSet(
-          projectList.flatMap(_.references.iterator)) ++
-        predicates.references ++
-        condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-      Project(projectList, Filter(predicates, Join(
-        prunedChild(left, allReferences), prunedChild(right, allReferences), joinType, condition)))
-
     // Eliminate unneeded attributes from right side of a LeftSemiJoin.
     case Join(left, right, LeftSemi, condition) =>
       // Collect the list of all references required to evaluate the condition.
@@ -922,3 +913,98 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
     a.copy(groupingExpressions = newGrouping)
   }
 }
+
+/**
+ * If there are equal-join conditions, but the join order prevents them from being seen
+ * by the optimizer, we will adjust the join order so that the join condition can be pushed
+ * down to the join operator. This avoids a cartesian product in the physical plan.
+ */
+object AdjustJoinOrderWithEqualConditions extends Rule[LogicalPlan] with PredicateHelper{
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // only consider Inner Join
+    case f @ Filter(conds, join @ Join(leftPlan, rightPlan, joinType, _))
+      if joinType == Inner =>
+      val (joins, relations, joinConds) = splitJoinRelationsNodes(join)
+      val allFilterConds = splitConjunctivePredicates(conds) ++ joinConds
+      newOperator(f, allFilterConds, joins, relations)
+    case join @ Join(leftPlan, rightPlan, joinType, _) if joinType == Inner =>
+      val (joins, relations, joinConds) = splitJoinRelationsNodes(join)
+      val allFilterConds = joinConds
+      newOperator(join, allFilterConds, joins, relations)
+  }
+
+  def splitJoinRelationsNodes(join: Join) : (Seq[Join], Seq[LogicalPlan], Seq[Expression]) = {
+    var joins = new collection.mutable.ArrayBuffer[Join]()
+    var relations = new collection.mutable.ArrayBuffer[LogicalPlan]()
+    var joinConds = new collection.mutable.ArrayBuffer[Expression]()
+    var queue = new collection.mutable.ArrayBuffer[Join]()
+    queue += join
+    while(!queue.isEmpty){
+      val curNode = queue(0)
+      joins += curNode
+      queue = queue.drop(1)
+      curNode.asInstanceOf[Join].condition match {
+        case Some(e) => joinConds ++=
+          splitConjunctivePredicates(curNode.asInstanceOf[Join].condition.get)
+        case None => joinConds ++= Seq.empty[Expression]
+      }
+
+      if(curNode.left.isInstanceOf[Join] && curNode.left.asInstanceOf[Join].joinType == Inner){
+        queue += curNode.left.asInstanceOf[Join]
+      }
+      else relations += curNode.left
+      if(curNode.right.isInstanceOf[Join] && curNode.right.asInstanceOf[Join].joinType == Inner){
+        queue += curNode.right.asInstanceOf[Join]
+      }
+      else relations += curNode.right
+    }
+    (joins, relations, joinConds)
+  }
+
+  def newOperator(plan: LogicalPlan, allFilterConds: Seq[Expression],
+    joins: Seq[Join], relations: Seq[LogicalPlan]) : LogicalPlan = {
+    val equalConds = allFilterConds.filter {
+      case EqualTo(l, r) => true
+      case _ => false
+    }
+
+    if(joins.length <= 1 || joins.length + 1 < relations.length) plan
+    else {
+      if (allFilterConds.isEmpty) plan
+      else Filter(allFilterConds.reduceLeft(And), shiftJoinOrder(relations, equalConds))
+    }
+  }
+
+  def shiftJoinOrder(relations: Seq[LogicalPlan], equalConds: Seq[Expression]) : Join = {
+    var finished : Boolean = false
+    var index : Int = 0
+    var relationsMap: Map[LogicalPlan, Boolean] = relations.map(r => (r -> true)).toMap
+    while(!finished){
+      if (relationsMap.size == 1 || index == equalConds.length) {
+        finished = true
+      }
+      else {
+        val equalCond = equalConds(index)
+        val left = equalCond.asInstanceOf[EqualTo].left.references
+        val lj = relationsMap.keys.toSeq.find(r => left.size > 0 && left.subsetOf(r.outputSet))
+        if(lj != None){
+          val right = equalCond.asInstanceOf[EqualTo].right.references
+          val rj = relationsMap.keys.toSeq.find(r => right.size > 0 && right.subsetOf(r.outputSet))
+          if(rj != None) {
+            if (!lj.get.fastEquals(rj.get)){
+              relationsMap -= rj.get
+              relationsMap -= lj.get
+              relationsMap += (Join(lj.get, rj.get, Inner, None) -> true)
+            }
+          }
+        }
+        index += 1
+      }
+    }
+    relationsMap.keys.toSeq.reduceLeft(combineJoin).asInstanceOf[Join]
+  }
+
+  def combineJoin(left: LogicalPlan, right: LogicalPlan) : LogicalPlan = {
+    Join(left, right, Inner, None)
+  }
+}