From c4b88d5e74fd2bd3ba0f42a91359ac43c51cba11 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 7 Apr 2022 23:39:01 +0800 Subject: [PATCH] propagate distinct keys more precisely --- .../plans/logical/DistinctKeyVisitor.scala | 20 +++++++++++++++++-- .../logical/DistinctKeyVisitorSuite.scala | 7 ++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala index bb2bc4e3d2f93..726c52592887f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala @@ -50,11 +50,27 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] { }.filter(_.nonEmpty) } + /** + * Add a new ExpressionSet S into distinctKeys D. + * To minimize the size of D: + * 1. If there is a subset of S in D, return D. + * 2. Otherwise, remove all the ExpressionSet containing S from D, and add the new one. + */ + private def addDistinctKey( + keys: Set[ExpressionSet], + newExpressionSet: ExpressionSet): Set[ExpressionSet] = { + if (keys.exists(_.subsetOf(newExpressionSet))) { + keys + } else { + keys.filterNot(s => newExpressionSet.subsetOf(s)) + newExpressionSet + } + } + override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet] override def visitAggregate(p: Aggregate): Set[ExpressionSet] = { val groupingExps = ExpressionSet(p.groupingExpressions) // handle group by a, a - projectDistinctKeys(Set(groupingExps), p.aggregateExpressions) + projectDistinctKeys(addDistinctKey(p.child.distinctKeys, groupingExps), p.aggregateExpressions) } override def visitDistinct(p: Distinct): Set[ExpressionSet] = Set(ExpressionSet(p.output)) @@ -70,7 +86,7 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] { override def visitGlobalLimit(p: GlobalLimit): Set[ExpressionSet] = { p.maxRows match { - case Some(value) if value <= 1 => Set(ExpressionSet(p.output)) + case Some(value) if value <= 1 => p.output.map(attr => ExpressionSet(Seq(attr))).toSet case _ => p.child.distinctKeys } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala index b884b27fe3b08..80342f6dd7a78 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala @@ -66,6 +66,10 @@ class DistinctKeyVisitorSuite extends PlanTest { Set(ExpressionSet(Seq(a)), ExpressionSet(Seq(d.toAttribute)))) checkDistinctAttributes(t1.groupBy(f.child, $"b")(f, $"b", sum($"c")), Set(ExpressionSet(Seq(f.toAttribute, b)))) + + // Aggregate should also propagate distinct keys from child + checkDistinctAttributes(t1.limit(1).groupBy($"a", $"b")($"a", $"b"), + Set(ExpressionSet(Seq(a)), ExpressionSet(Seq(b)))) } test("Distinct's distinct attributes") { @@ -86,7 +90,8 @@ class DistinctKeyVisitorSuite extends PlanTest { test("Limit's distinct attributes") { checkDistinctAttributes(Distinct(t1).limit(10), Set(ExpressionSet(Seq(a, b, c)))) checkDistinctAttributes(LocalLimit(10, Distinct(t1)), Set(ExpressionSet(Seq(a, b, c)))) - checkDistinctAttributes(t1.limit(1), Set(ExpressionSet(Seq(a, b, c)))) + checkDistinctAttributes(t1.limit(1), + Set(ExpressionSet(Seq(a)), ExpressionSet(Seq(b)), ExpressionSet(Seq(c)))) } test("Intersect's distinct attributes") {