diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 66c2ad84ccee8..bb788336c6d77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -146,7 +146,7 @@ abstract class Optimizer(catalogManager: CatalogManager) PushDownPredicates) :: Nil } - val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) :: + val batches = ( // Technically some of the rules in Finish Analysis are not optimizer rules and belong more // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). // However, because we also use the analyzer to canonicalized queries (for view definition), @@ -166,6 +166,7 @@ abstract class Optimizer(catalogManager: CatalogManager) ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// + Batch("Eliminate Distinct", Once, EliminateDistinct) :: // - Do the first call of CombineUnions before starting the major Optimizer rules, // since it can reduce the number of iteration and the other rules could add/move // extra operators between two adjacent Union operators. @@ -411,14 +412,26 @@ abstract class Optimizer(catalogManager: CatalogManager) } /** - * Remove useless DISTINCT for MAX and MIN. + * Remove useless DISTINCT: + * 1. For some aggregate expression, e.g.: MAX and MIN. + * 2. If the distinct semantics is guaranteed by child. + * * This rule should be applied before RewriteDistinctAggregates. */ object EliminateDistinct extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( - _.containsPattern(AGGREGATE_EXPRESSION)) { - case ae: AggregateExpression if ae.isDistinct && isDuplicateAgnostic(ae.aggregateFunction) => - ae.copy(isDistinct = false) + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsPattern(AGGREGATE)) { + case agg: Aggregate => + agg.transformExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) { + case ae: AggregateExpression if ae.isDistinct && + isDuplicateAgnostic(ae.aggregateFunction) => + ae.copy(isDistinct = false) + + case ae: AggregateExpression if ae.isDistinct && + agg.child.distinctKeys.exists( + _.subsetOf(ExpressionSet(ae.aggregateFunction.children.filterNot(_.foldable)))) => + ae.copy(isDistinct = false) + } } def isDuplicateAgnostic(af: AggregateFunction): Boolean = af match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala index 1843c2da478ef..2ffa5a0e594e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala @@ -29,6 +29,12 @@ import org.apache.spark.sql.internal.SQLConf.PROPAGATE_DISTINCT_KEYS_ENABLED */ trait LogicalPlanDistinctKeys { self: LogicalPlan => lazy val distinctKeys: Set[ExpressionSet] = { - if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) DistinctKeyVisitor.visit(self) else Set.empty + if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) { + val keys = DistinctKeyVisitor.visit(self) + require(keys.forall(_.nonEmpty)) + keys + } else { + Set.empty + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala index 9c57ced8492b8..798cc0a42dd3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala @@ -33,6 +33,7 @@ class EliminateDistinctSuite extends PlanTest { } val testRelation = LocalRelation($"a".int) + val testRelation2 = LocalRelation($"a".int, $"b".string) Seq( Max(_), @@ -71,4 +72,21 @@ class EliminateDistinctSuite extends PlanTest { comparePlans(Optimize.execute(query), answer) } } + + test("SPARK-38832: Remove unnecessary distinct in aggregate expression by distinctKeys") { + val q1 = testRelation2.groupBy($"a")($"a") + .rebalance().groupBy()(countDistinct($"a") as "x", sumDistinct($"a") as "y").analyze + val r1 = testRelation2.groupBy($"a")($"a") + .rebalance().groupBy()(count($"a") as "x", sum($"a") as "y").analyze + comparePlans(Optimize.execute(q1), r1) + + // not a subset of distinct attr + val q2 = testRelation2.groupBy($"a", $"b")($"a", $"b") + .rebalance().groupBy()(countDistinct($"a") as "x", sumDistinct($"a") as "y").analyze + comparePlans(Optimize.execute(q2), q2) + + // child distinct key is empty + val q3 = testRelation2.groupBy($"a")(countDistinct($"a") as "x").analyze + comparePlans(Optimize.execute(q3), q3) + } }