From a1cf288e01dd92150d158b96db2ffb292dfa2478 Mon Sep 17 00:00:00 2001 From: Prakhar Jain Date: Thu, 19 Nov 2020 15:52:55 +0530 Subject: [PATCH 1/5] Refactor SortOrder class --- .../apache/spark/sql/catalyst/expressions/SortOrder.scala | 4 +++- .../spark/sql/execution/AliasAwareOutputExpression.scala | 6 +----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 54259e713accd..c2e20e68ae089 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -64,7 +64,9 @@ case class SortOrder( direction: SortDirection, nullOrdering: NullOrdering, sameOrderExpressions: Set[Expression]) - extends UnaryExpression with Unevaluable { + extends Expression with Unevaluable { + + override def children: Seq[Expression] = child +: sameOrderExpressions.toSeq override def checkInputDataTypes(): TypeCheckResult = { if (RowOrdering.isOrderable(dataType)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala index 3ba8745be995f..3cbe1654ea2cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala @@ -65,11 +65,7 @@ trait AliasAwareOutputOrdering extends AliasAwareOutputExpression { final override def outputOrdering: Seq[SortOrder] = { if (hasAlias) { - orderingExpressions.map { sortOrder => - val newSortOrder = normalizeExpression(sortOrder).asInstanceOf[SortOrder] - val newSameOrderExpressions = newSortOrder.sameOrderExpressions.map(normalizeExpression) - newSortOrder.copy(sameOrderExpressions = newSameOrderExpressions) - } + orderingExpressions.map(normalizeExpression(_).asInstanceOf[SortOrder]) } else { orderingExpressions } From 9510b5fcf6a91efccd522460953d4c5f6e191480 Mon Sep 17 00:00:00 2001 From: Prakhar Jain Date: Fri, 20 Nov 2020 19:28:54 +0530 Subject: [PATCH 2/5] use children in sameOrder --- .../org/apache/spark/sql/catalyst/expressions/SortOrder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index c2e20e68ae089..f2e9f5d6954f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -85,7 +85,7 @@ case class SortOrder( def isAscending: Boolean = direction == Ascending def satisfies(required: SortOrder): Boolean = { - (sameOrderExpressions + child).exists(required.child.semanticEquals) && + children.exists(required.child.semanticEquals) && direction == required.direction && nullOrdering == required.nullOrdering } } From 1c38e979a35dbaafbac013254e3bc35befde7cc8 Mon Sep 17 00:00:00 2001 From: Prakhar Jain Date: Wed, 25 Nov 2020 20:13:40 +0530 Subject: [PATCH 3/5] convert same order expressions to Seq --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 4 +-- .../sql/catalyst/expressions/SortOrder.scala | 6 ++-- 
.../sql/catalyst/parser/AstBuilder.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 8 +++--- .../execution/joins/SortMergeJoinExec.scala | 9 +++--- .../spark/sql/execution/PlannerSuite.scala | 28 +++++++++++++++++++ 7 files changed, 44 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 837686420375a..5624ce289dbc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1817,7 +1817,7 @@ class Analyzer(override val catalogManager: CatalogManager) val newOrders = orders map { case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => if (index > 0 && index <= child.output.size) { - SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty) + SortOrder(child.output(index - 1), direction, nullOrdering, Seq.empty) } else { s.failAnalysis( s"ORDER BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 4cd649b07a5c0..3f875154e769e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -131,9 +131,9 @@ package object dsl { } def asc: SortOrder = SortOrder(expr, Ascending) - def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty) + def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty) def desc: SortOrder = SortOrder(expr, Descending) - def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty) + def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty) def as(alias: String): NamedExpression = Alias(expr, alias)() def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index f2e9f5d6954f2..d9923b5d022e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -63,10 +63,10 @@ case class SortOrder( child: Expression, direction: SortDirection, nullOrdering: NullOrdering, - sameOrderExpressions: Set[Expression]) + sameOrderExpressions: Seq[Expression]) extends Expression with Unevaluable { - override def children: Seq[Expression] = child +: sameOrderExpressions.toSeq + override def children: Seq[Expression] = child +: sameOrderExpressions override def checkInputDataTypes(): TypeCheckResult = { if (RowOrdering.isOrderable(dataType)) { @@ -94,7 +94,7 @@ object SortOrder { def apply( child: Expression, direction: SortDirection, - sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = { + sameOrderExpressions: Seq[Expression] = Seq.empty): SortOrder = { new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5f8394c525949..7bc0a3c44b7ac 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1893,7 +1893,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } else { direction.defaultNullOrdering } - SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty) + SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c164835c753e8..c0216eb18719d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1226,7 +1226,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) } + def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Seq.empty) } /** * Returns a sort expression based on the descending order of the column, @@ -1242,7 +1242,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) } + def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Seq.empty) } /** * Returns a sort expression based on ascending order of the column. @@ -1273,7 +1273,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) } + def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Seq.empty) } /** * Returns a sort expression based on ascending order of the column, @@ -1289,7 +1289,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) } + def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Seq.empty) } /** * Prints the expression to the console for debugging purposes. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 6e59ad07d7168..ce3f11939170e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -68,9 +68,9 @@ case class SortMergeJoinExec( val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => - // Also add the right key and its `sameOrderExpressions` - SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey - .sameOrderExpressions) + // Also add expressions from right side sort order + val sameOrderExpressions = ExpressionSet(lKey.children ++ rKey.children) - lKey.child + SortOrder(lKey.child, Ascending, sameOrderExpressions.toSeq) } // For left and right outer joins, the output is ordered by the streamed input's join keys. 
      case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering)
@@ -96,7 +96,8 @@ case class SortMergeJoinExec(
     val requiredOrdering = requiredOrders(keys)
     if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) {
       keys.zip(childOutputOrdering).map { case (key, childOrder) =>
-        SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key)
+        val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key
+        SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq)
       }
     } else {
       requiredOrdering
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 6de81cc414d7d..b96d2067d7646 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1090,6 +1090,34 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("sort order doesn't have repeated expressions") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("t1") {
+        withTempView("t2") {
+          spark.range(10).repartition($"id").createTempView("t1")
+          spark.range(20).repartition($"id").createTempView("t2")
+          val planned = sql(
+            """
+              | SELECT t12.id, t1.id
+              | FROM (SELECT t1.id FROM t1, t2 WHERE t1.id * 2 = t2.id) t12, t1
+              | WHERE 2 * t12.id = t1.id
+            """.stripMargin).queryExecution.executedPlan
+
+          // t12 is already sorted on `t1.id * 2`, and we need to sort it on `2 * t12.id`
+          // for the 2nd join, so sorting on t12 can be avoided
+          val sortNodes = planned.collect { case s: SortExec => s }
+          assert(sortNodes.size == 3)
+          val outputOrdering = planned.outputOrdering
+          assert(outputOrdering.size == 1)
+          // Sort order should have 3 children, not 4, because t1.id * 2 and 2 * t1.id are the same
+          assert(outputOrdering.head.children.size == 3)
+          assert(outputOrdering.head.children.count(_.isInstanceOf[AttributeReference]) == 2)
+          assert(outputOrdering.head.children.count(_.isInstanceOf[Multiply]) == 1)
+        }
+      }
+    }
+  }
+
   test("aliases to expressions should not be replaced") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withTempView("df1", "df2") {

From b82540e7fbdee19d7b1215a76de09cbdf0b08793 Mon Sep 17 00:00:00 2001
From: Prakhar Jain
Date: Fri, 27 Nov 2020 19:35:57 +0530
Subject: [PATCH 4/5] review comments addressed

---
 .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ce3f11939170e..eabbdc8ed3243 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -69,7 +69,7 @@ case class SortMergeJoinExec(
       val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering)
       leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
         // Also add expressions from right side sort order
-        val sameOrderExpressions = ExpressionSet(lKey.children ++ rKey.children) - lKey.child
+        val sameOrderExpressions = ExpressionSet(lKey.sameOrderExpressions ++ rKey.children)
         SortOrder(lKey.child, Ascending, sameOrderExpressions.toSeq)
       }
     // For left and right outer joins, the output is ordered by the streamed input's join keys.
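For context (illustrative, not part of the patch series): `ExpressionSet`, used in both hunks above, deduplicates semantically equal expressions — `a * 2` and `2 * a` share a canonical form — which is exactly why the new PlannerSuite test expects 3 children rather than 4. A minimal sketch of that behavior, assuming spark-catalyst (with this series applied) on the classpath; the object and attribute names are made up for the example:

    import org.apache.spark.sql.catalyst.expressions.{
      AttributeReference, ExpressionSet, Literal, Multiply}
    import org.apache.spark.sql.types.LongType

    // Sketch: ExpressionSet keeps one representative per canonical form.
    object ExpressionSetDedupSketch extends App {
      val a = AttributeReference("a", LongType)()
      // `a * 2` and `2 * a` differ textually but canonicalize identically,
      // so only one of them survives in the set.
      val exprs = ExpressionSet(Seq(Multiply(a, Literal(2L)), Multiply(Literal(2L), a)))
      assert(exprs.size == 1)
    }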
From 1ac0d64ddbeaa88ac6eae42054bd55828097b408 Mon Sep 17 00:00:00 2001
From: Prakhar Jain
Date: Tue, 1 Dec 2020 12:43:26 +0530
Subject: [PATCH 5/5] withTempView refactor

---
 .../spark/sql/execution/PlannerSuite.scala | 40 +++++++++----------
 1 file changed, 19 insertions(+), 21 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index b96d2067d7646..5e30f846307ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1092,28 +1092,26 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
   test("sort order doesn't have repeated expressions") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      withTempView("t1") {
-        withTempView("t2") {
-          spark.range(10).repartition($"id").createTempView("t1")
-          spark.range(20).repartition($"id").createTempView("t2")
-          val planned = sql(
-            """
-              | SELECT t12.id, t1.id
-              | FROM (SELECT t1.id FROM t1, t2 WHERE t1.id * 2 = t2.id) t12, t1
-              | WHERE 2 * t12.id = t1.id
-            """.stripMargin).queryExecution.executedPlan
+      withTempView("t1", "t2") {
+        spark.range(10).repartition($"id").createTempView("t1")
+        spark.range(20).repartition($"id").createTempView("t2")
+        val planned = sql(
+          """
+            | SELECT t12.id, t1.id
+            | FROM (SELECT t1.id FROM t1, t2 WHERE t1.id * 2 = t2.id) t12, t1
+            | WHERE 2 * t12.id = t1.id
+          """.stripMargin).queryExecution.executedPlan
 
-          // t12 is already sorted on `t1.id * 2`, and we need to sort it on `2 * t12.id`
-          // for the 2nd join, so sorting on t12 can be avoided
-          val sortNodes = planned.collect { case s: SortExec => s }
-          assert(sortNodes.size == 3)
-          val outputOrdering = planned.outputOrdering
-          assert(outputOrdering.size == 1)
-          // Sort order should have 3 children, not 4, because t1.id * 2 and 2 * t1.id are the same
-          assert(outputOrdering.head.children.size == 3)
-          assert(outputOrdering.head.children.count(_.isInstanceOf[AttributeReference]) == 2)
-          assert(outputOrdering.head.children.count(_.isInstanceOf[Multiply]) == 1)
-        }
+        // t12 is already sorted on `t1.id * 2`, and we need to sort it on `2 * t12.id`
+        // for the 2nd join, so sorting on t12 can be avoided
+        val sortNodes = planned.collect { case s: SortExec => s }
+        assert(sortNodes.size == 3)
+        val outputOrdering = planned.outputOrdering
+        assert(outputOrdering.size == 1)
+        // Sort order should have 3 children, not 4, because t1.id * 2 and 2 * t1.id are the same
+        assert(outputOrdering.head.children.size == 3)
+        assert(outputOrdering.head.children.count(_.isInstanceOf[AttributeReference]) == 2)
+        assert(outputOrdering.head.children.count(_.isInstanceOf[Multiply]) == 1)
       }
     }
   }
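To close the loop on what the series buys (again illustrative, not part of the patches): after PATCH 2/5, `SortOrder.satisfies` walks `children`, i.e. `child +: sameOrderExpressions`, so an ordering recorded for one join key also satisfies a required sort on its counterpart, and the planner can elide a SortExec. A minimal sketch under the same spark-catalyst classpath assumption; the attribute and object names are made up:

    import org.apache.spark.sql.catalyst.expressions.{
      Ascending, AttributeReference, Descending, SortOrder}
    import org.apache.spark.sql.types.LongType

    object SortOrderSatisfiesSketch extends App {
      val a = AttributeReference("a", LongType)()
      val b = AttributeReference("b", LongType)()

      // After a sort-merge inner join on a = b, the output is sorted on `a`
      // and `b` is recorded as a same-order expression, mirroring
      // SortMergeJoinExec.outputOrdering in PATCH 4/5.
      val ordering = SortOrder(a, Ascending, Seq(b))

      // A required ascending sort on `b` is satisfied through `children`
      // (child +: sameOrderExpressions), so no extra sort is needed.
      assert(ordering.satisfies(SortOrder(b, Ascending)))

      // Direction (and null ordering) must still match.
      assert(!ordering.satisfies(SortOrder(b, Descending)))
    }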