diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 193082eb77024..98b692739618e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1197,11 +1197,6 @@ class Analyzer( case u @ UnresolvedFunction(funcId, children, isDistinct) => withPosition(u) { catalog.lookupFunction(funcId, children) match { - // DISTINCT is not meaningful for a Max or a Min. - case max: Max if isDistinct => - AggregateExpression(max, Complete, isDistinct = false) - case min: Min if isDistinct => - AggregateExpression(min, Complete, isDistinct = false) // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index beee93d906f0f..f6792569b704e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,7 +159,9 @@ package object dsl { def first(e: Expression): Expression = new First(e).toAggregateExpression() def last(e: Expression): Expression = new Last(e).toAggregateExpression() def min(e: Expression): Expression = Min(e).toAggregateExpression() + def minDistinct(e: Expression): Expression = Min(e).toAggregateExpression(isDistinct = true) def max(e: Expression): Expression = Max(e).toAggregateExpression() + def maxDistinct(e: Expression): Expression = Max(e).toAggregateExpression(isDistinct = true) def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression 
= Sqrt(e)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b410312030c5d..946fa7bae0199 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -40,6 +40,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
   protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations)
 
   def batches: Seq[Batch] = {
+    Batch("Eliminate Distinct", Once, EliminateDistinct) ::
     // Technically some of the rules in Finish Analysis are not optimizer rules and belong more
     // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
     // However, because we also use the analyzer to canonicalized queries (for view definition),
@@ -151,6 +152,25 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
   def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
 }
 
+/**
+ * Remove useless DISTINCT for MAX and MIN.
+ * This rule should be applied before RewriteDistinctAggregates.
+ *
+ * The rewrite clears `isDistinct` on every matching [[AggregateExpression]] and leaves all
+ * other expressions untouched. It uses `transformAllExpressions` so that operators below the
+ * root of the plan (for example an Aggregate under a Sort or Project) are rewritten as well;
+ * `transformExpressions` would only visit the expressions of the root operator.
+ */
+object EliminateDistinct extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case ae: AggregateExpression if ae.isDistinct =>
+      ae.aggregateFunction match {
+        case _: Max | _: Min => ae.copy(isDistinct = false)
+        case _ => ae
+      }
+  }
+}
+
 /**
  * An optimizer used in test code.
*
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala
new file mode 100644
index 0000000000000..f40691bd1a038
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+/**
+ * Tests for [[EliminateDistinct]]: a DISTINCT flag on MAX or MIN is dropped by the rule.
+ */
+class EliminateDistinctSuite extends PlanTest {
+
+  // Run only the rule under test, a single time.
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Operator Optimizations", Once, EliminateDistinct) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int)
+
+  /**
+   * Analyzes both plans, checks that they differ before optimization, and then checks that
+   * optimizing `distinctPlan` produces the analyzed `expectedPlan`.
+   */
+  private def checkDistinctEliminated(
+      distinctPlan: LogicalPlan,
+      expectedPlan: LogicalPlan): Unit = {
+    val query = distinctPlan.analyze
+    val answer = expectedPlan.analyze
+    // Guard against a vacuous test: the inputs must not already be equal.
+    assert(query != answer)
+    comparePlans(Optimize.execute(query), answer)
+  }
+
+  test("Eliminate Distinct in Max") {
+    checkDistinctEliminated(
+      testRelation.select(maxDistinct('a).as('result)),
+      testRelation.select(max('a).as('result)))
+  }
+
+  test("Eliminate Distinct in Min") {
+    checkDistinctEliminated(
+      testRelation.select(minDistinct('a).as('result)),
+      testRelation.select(min('a).as('result)))
+  }
+}