Analyzer.scala

@@ -1922,15 +1922,9 @@ class Analyzer(
         }
       // We get an aggregate function, we need to wrap it in an AggregateExpression.
       case agg: AggregateFunction =>
-        // TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT
-        if (filter.isDefined) {
-          if (isDistinct) {
-            failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " +
-              "at the same time")
-          } else if (!filter.get.deterministic) {
-            failAnalysis("FILTER expression is non-deterministic, " +
-              "it cannot be used in aggregate functions")
-          }
-        }
+        if (filter.isDefined && !filter.get.deterministic) {
+          failAnalysis("FILTER expression is non-deterministic, " +
+            "it cannot be used in aggregate functions")
+        }
         AggregateExpression(agg, Complete, isDistinct, filter)
       // This function is not an aggregate function, just return the resolved one.
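With the DISTINCT-plus-FILTER restriction lifted in the analyzer, a query like the following — previously rejected with "DISTINCT and FILTER cannot be used in aggregate functions at the same time" — should now analyze and run. A minimal sketch (table and column names are illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

Seq((1, "a", 10), (2, "a", 5), (3, "b", 13))
  .toDF("id", "key", "value")
  .createOrReplaceTempView("data")

// DISTINCT and FILTER on the same aggregate function.
spark.sql(
  """SELECT key,
    |       COUNT(DISTINCT value) FILTER (WHERE id > 1) AS cnt
    |FROM data
    |GROUP BY key""".stripMargin).show()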
dsl/package.scala

@@ -173,6 +173,8 @@ package object dsl {
     def count(e: Expression): Expression = Count(e).toAggregateExpression()
     def countDistinct(e: Expression*): Expression =
       Count(e).toAggregateExpression(isDistinct = true)
+    def countDistinct(filter: Option[Expression], e: Expression*): Expression =
+      Count(e).toAggregateExpression(isDistinct = true, filter = filter)
     def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
       HyperLogLogPlusPlus(e, rsd).toAggregateExpression()
     def avg(e: Expression): Expression = Average(e).toAggregateExpression()
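The new overload lets Catalyst tests build a filtered distinct count directly, as the suite added later in this PR does. A minimal sketch using the DSL (relation and attribute names are illustrative):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('key.string, 'cat.string, 'id.int)
// COUNT(DISTINCT cat) FILTER (WHERE id > 1), built through the new overload.
val plan = relation
  .groupBy('key)(countDistinct(Some(GreaterThan('id, Literal(1))), 'cat))
  .analyze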
interfaces.scala

@@ -216,15 +216,21 @@ abstract class AggregateFunction extends Expression {
   def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false)
 
   /**
-   * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct`
-   * flag of the [[AggregateExpression]] to the given value because
+   * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets its `isDistinct`
+   * flag and optional `filter` to the given values, because
    * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode,
-   * and the flag indicating if this aggregation is distinct aggregation or not.
-   * An [[AggregateFunction]] should not be used without being wrapped in
+   * the flag indicating if this aggregation is distinct aggregation or not, and the optional
+   * `filter`. An [[AggregateFunction]] should not be used without being wrapped in
    * an [[AggregateExpression]].
    */
-  def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
-    AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
+  def toAggregateExpression(
+      isDistinct: Boolean,
+      filter: Option[Expression] = None): AggregateExpression = {
+    AggregateExpression(
+      aggregateFunction = this,
+      mode = Complete,
+      isDistinct = isDistinct,
+      filter = filter)
   }
 
   def sql(isDistinct: Boolean): String = {
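For illustration, a sketch of calling the extended signature directly (attribute names are made up for the example):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.types.IntegerType

val value = AttributeReference("value", IntegerType)()
val id = AttributeReference("id", IntegerType)()

// SUM(value) FILTER (WHERE id > 1) as a Catalyst AggregateExpression. The default
// filter = None keeps all existing call sites source-compatible.
val aggExpr = Sum(value).toAggregateExpression(
  isDistinct = false,
  filter = Some(GreaterThan(id, Literal(1))))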
Optimizer.scala

@@ -140,6 +140,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteNonCorrelatedExists,
       ComputeCurrentTime,
       GetCurrentDatabaseAndCatalog(catalogManager),
+      ProjectFilterInAggregates,
       RewriteDistinctAggregates,
       ReplaceDeduplicateWithAggregate) ::
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -237,6 +238,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     ReplaceExpressions.ruleName ::
     ComputeCurrentTime.ruleName ::
     GetCurrentDatabaseAndCatalog(catalogManager).ruleName ::
+    ProjectFilterInAggregates.ruleName ::
    RewriteDistinctAggregates.ruleName ::
     ReplaceDeduplicateWithAggregate.ruleName ::
     ReplaceIntersectWithSemiJoin.ruleName ::
ProjectFilterInAggregates.scala (new file)

@@ -0,0 +1,196 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, If, IsNotNull, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

/**
* When an aggregate query contains filter clauses, this rule inserts a project node below
* the aggregate, so that the filters are evaluated against the aggregate's input in advance.
*
* First example: a query with filter clauses (in SQL):
* {{{
* val data = Seq(
* (1, "a", "ca1", "cb1", 10),
* (2, "a", "ca1", "cb2", 5),
* (3, "b", "ca1", "cb1", 13))
* .toDF("id", "key", "cat1", "cat2", "value")
* data.createOrReplaceTempView("data")
*
* SELECT
* COUNT(DISTINCT cat1) AS cat1_cnt,
* COUNT(DISTINCT cat2) FILTER (WHERE id > 1) AS cat2_cnt,
* SUM(value) AS total,
* SUM(value) FILTER (WHERE key = "a") AS total2
* FROM
* data
* GROUP BY
* key
* }}}
*
* This translates to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1),
* COUNT(DISTINCT 'cat2) with FILTER('id > 1),
* SUM('value),
* SUM('value) with FILTER('key = "a")]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total, 'total2])
* LocalTableScan [...]
* }}}
*
* This rule rewrites this logical plan to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT '_gen_attr_1),
* COUNT(DISTINCT '_gen_attr_2) with FILTER('_gen_attr_2 is not null),
* SUM('_gen_attr_3),
* SUM('_gen_attr_4) with FILTER('_gen_attr_4 is not null)]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total, 'total2])
* Project(
* projectionList = ['key,
* 'cat1,
* if ('id > 1) 'cat2 else null,
* cast('value as bigint),
* if ('key = "a") cast('value as bigint) else null]
* output = ['key, '_gen_attr_1, '_gen_attr_2, '_gen_attr_3, '_gen_attr_4])
[Review thread anchored on this line]
Contributor: cat1 is not related to the filter, why do we change its name to _gen_attr_1?
Author: For convenience and unification, we always alias the column, even if there is no filter.
* LocalTableScan [...]
* }}}
*
* The rule does the following things here:
* 1. Project the output of the child of the aggregate query. There are two kinds of
*    aggregate expressions in such a query:
*    i. those without a filter clause;
*    ii. those with a filter clause.
*    When at least one aggregate function carries a filter clause, we add a project
*    node on top of the input plan.
* 2. Avoid projections that would emit the same attribute more than once. Three kinds of
*    aggregation groups can reference one column:
*    i. the non-distinct 'cat1 group without a filter clause;
*    ii. the distinct 'cat1 group without a filter clause;
*    iii. the distinct 'cat1 group with a filter clause.
*    The attributes referenced by different aggregate expressions are likely to overlap;
*    if they were emitted directly, we could end up with three identical 'cat1 attributes
*    and lose data. To prevent this, we generate new attributes (e.g. '_gen_attr_1) and
*    replace the original ones.
*
* Why is the projection phase needed? It guarantees that the filter clauses are computed
* locally, before the first aggregate.
* Note: after the new attributes are generated, the aggregate may contain at least two
* distinct groups, and may therefore trigger [[RewriteDistinctAggregates]].
*/
object ProjectFilterInAggregates extends Rule[LogicalPlan] {

private def collectAggregateExprs(exprs: Seq[Expression]): Seq[AggregateExpression] = {
exprs.flatMap { _.collect {
case ae: AggregateExpression => ae
}}
}

// Returns true when at least one aggregate expression carries a filter clause and this
// rule has not already projected the child (no generated "_gen_attr_" columns yet).
private def mayNeedToProject(a: Aggregate): Boolean = {
if (collectAggregateExprs(a.aggregateExpressions).exists(_.filter.isDefined)) {
var flag = true
a resolveOperatorsUp {
case p: Project =>
if (p.output.exists(_.name.startsWith("_gen_attr_"))) {
flag = false
}
p
case other => other
}
flag
} else {
false
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a: Aggregate if mayNeedToProject(a) => project(a)
}

def project(a: Aggregate): Aggregate = {
val aggExpressions = collectAggregateExprs(a.aggregateExpressions)
// Constructs pairs between old and new expressions for aggregates.
val aggExprs = aggExpressions.filter(e => e.children.exists(!_.foldable))
var currentExprId = 0
val (projections, aggPairs) = aggExprs.map {
case ae @ AggregateExpression(af, _, _, filter, _) =>
// First, to reduce costs it is better to evaluate the filter clause locally:
// e.g. for COUNT(DISTINCT a) FILTER (WHERE id > 1), evaluate the expression
// If('id > 1) 'a else null first, and use its result as the aggregate's input.
// Second, two or more DISTINCT aggregate expressions may reference the same
// attributes, so we generate new attributes to keep their outputs apart.
// e.g. SUM(DISTINCT a), COUNT(DISTINCT a) FILTER (WHERE id > 1) will output the
// attributes '_gen_attr_1 and '_gen_attr_2 instead of 'a twice.
// Note: this aliasing may leave the aggregate with at least two distinct groups,
// so RewriteDistinctAggregates may rewrite the logical plan afterwards.
val unfoldableChildren = af.children.filter(!_.foldable)
// Expand projection
val projectionMap = unfoldableChildren.map { e =>
currentExprId += 1
val ne = if (filter.isDefined) {
If(filter.get, e, Literal.create(null, e.dataType))
} else {
e
}
// For convenience and unification, we always alias the column, even if
// there is no filter.
e -> Alias(ne, s"_gen_attr_$currentExprId")()
}
val projection = projectionMap.map(_._2)
val exprAttrs = projectionMap.map { kv =>
(kv._1, kv._2.toAttribute)
}
val exprAttrLookup = exprAttrs.toMap
val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c))
val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
val aggExpr = if (filter.isDefined) {
// When the filter evaluates to false, the conditional expression outputs null,
// which would affect the results of aggregate functions that do not ignore
// nulls (e.g. count). So we replace the filter with an IsNotNull check on the
// generated attribute.
ae.copy(aggregateFunction = raf, filter = Some(IsNotNull(newChildren.last)))
} else {
ae.copy(aggregateFunction = raf, filter = None)
}

(projection, (ae, aggExpr))
}.unzip
// Construct the aggregate input projection.
val namedGroupingProjection = a.groupingExpressions.flatMap { e =>
e.collect {
case ar: AttributeReference => ar
}
}
val rewriteAggProjection = namedGroupingProjection ++ projections.flatten
// Construct the project operator.
val project = Project(rewriteAggProjection, a.child)
val rewriteAggExprLookup = aggPairs.toMap
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae)
}.asInstanceOf[NamedExpression]
}
Aggregate(a.groupingExpressions, patchedAggExpressions, project)
}
}
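To see why the If(filter, e, null) projection plus the IsNotNull filter preserves semantics, here is a hand-written sketch of the rewrite on the scaladoc's example data, assuming an existing SparkSession `spark` (all names are illustrative):

import spark.implicits._

val data = Seq(
  (1, "a", "ca1", "cb1", 10),
  (2, "a", "ca1", "cb2", 5),
  (3, "b", "ca1", "cb1", 13)).toDF("id", "key", "cat1", "cat2", "value")

// Filters become If(..., col, null) projections computed before the aggregate.
val projected = data.selectExpr(
  "key",
  "cat1 AS _gen_attr_1",
  "IF(id > 1, cat2, NULL) AS _gen_attr_2",
  "CAST(value AS BIGINT) AS _gen_attr_3",
  "IF(key = 'a', CAST(value AS BIGINT), NULL) AS _gen_attr_4")
projected.createOrReplaceTempView("projected")

// Each formerly-filtered aggregate now filters on IS NOT NULL, so the injected
// null rows are excluded rather than aggregated.
spark.sql(
  """SELECT key,
    |       COUNT(DISTINCT _gen_attr_1) AS cat1_cnt,
    |       COUNT(DISTINCT _gen_attr_2) FILTER (WHERE _gen_attr_2 IS NOT NULL) AS cat2_cnt,
    |       SUM(_gen_attr_3) AS total,
    |       SUM(_gen_attr_4) FILTER (WHERE _gen_attr_4 IS NOT NULL) AS total2
    |FROM projected
    |GROUP BY key""".stripMargin).show()
// key=a: cat1_cnt=1, cat2_cnt=1 (only id=2 passes), total=15, total2=15
// key=b: cat1_cnt=1, cat2_cnt=1, total=13, total2=NULL (no row has key='a')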
RewriteDistinctAggregates.scala

@@ -226,7 +226,10 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
           val naf = patchAggregateFunctionChildren(af) { x =>
             distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
           }
-          (e, e.copy(aggregateFunction = naf, isDistinct = false))
+          val filterOpt = e.filter.map(_.transform {
+            case a: Attribute => distinctAggChildAttrLookup.getOrElse(a, a)
+          })
+          (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = filterOpt))
         }
 
       (projection, operators)
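Both this change and the AggUtils change below rebase the filter's attribute references onto the projected/expanded attributes via transform. The pattern in a standalone sketch (names are illustrative):

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.types.IntegerType

val id = AttributeReference("id", IntegerType)()
val genId = AttributeReference("_gen_attr_1", IntegerType)()
val lookup: Map[Attribute, Attribute] = Map(id -> genId)

// FILTER (WHERE id > 1) rewritten to refer to the projected column instead.
val rebased = GreaterThan(id, Literal(1)).transform {
  case a: Attribute => lookup.getOrElse(a, a)
}
// rebased == GreaterThan(genId, Literal(1))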
AnalysisErrorSuite.scala

@@ -207,11 +207,6 @@ class AnalysisErrorSuite extends AnalysisTest {
       "FILTER (WHERE c > 1)"),
     "FILTER predicate specified, but aggregate is not an aggregate function" :: Nil)
 
-  errorTest(
-    "DISTINCT aggregate function with filter predicate",
-    CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"),
-    "DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil)
-
   errorTest(
     "non-deterministic filter predicate in aggregate functions",
     CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"),
ProjectFilterInAggregatesSuite.scala (new file)

@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

class ProjectFilterInAggregatesSuite extends PlanTest {
override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
val analyzer = new Analyzer(catalog, conf)

val nullInt = Literal(null, IntegerType)
val nullString = Literal(null, StringType)
val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

private def checkGenerate(generate: LogicalPlan): Unit = generate match {
case Aggregate(_, _, _: Project) =>
case _ => fail(s"Expected an Aggregate over a Project, but got:\n$generate")
}

test("single distinct group with filter") {
val input = testRelation
.groupBy('a)(countDistinct(Some(EqualTo('d, Literal(""))), 'e))
.analyze
checkGenerate(ProjectFilterInAggregates(input))
}

test("at least one distinct group with filter") {
val input = testRelation
.groupBy('a)(countDistinct(Some(EqualTo('d, Literal(""))), 'e), countDistinct('d))
.analyze
checkGenerate(ProjectFilterInAggregates(input))
}

}
AggUtils.scala

@@ -201,7 +201,10 @@ object AggUtils {
       // its input will have distinct arguments.
       // We just keep the isDistinct setting to true, so when users look at the query plan,
       // they still can see distinct aggregations.
-      val expr = AggregateExpression(func, Partial, isDistinct = true)
+      val filter = functionsWithDistinct(i).filter.map(_.transform {
+        case a: Attribute => distinctColumnAttributeLookup.getOrElse(a, a)
+      })
+      val expr = AggregateExpression(func, Partial, isDistinct = true, filter)
       // Use original AggregationFunction to lookup attributes, which is used to build
       // aggregateFunctionToAttribute
       val attr = functionsWithDistinct(i).resultAttribute