Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ object Canonicalize {
case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)

case o: Or =>
orderCommutative(o, { case Or(l, r) if l.deterministic && r.deterministic => Seq(l, r) })
orderCommutative(o, { case Or(l, r) if l.idempotent && r.idempotent => Seq(l, r) })
.reduce(Or)
case a: And =>
orderCommutative(a, { case And(l, r) if l.deterministic && r.deterministic => Seq(l, r)})
orderCommutative(a, { case And(l, r) if l.idempotent && r.idempotent => Seq(l, r)})
.reduce(And)

case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ abstract class Expression extends TreeNode[Expression] {
*/
lazy val deterministic: Boolean = children.forall(_.deterministic)

/**
* Returns true iff evaluating this expression can have an observable effect other than producing
* its result value. This is the case for expressions which throw exceptions in certain conditions.
* By default this is true iff at least one child has a side effect, so an expression without
* children reports false unless it overrides this value.
*/
lazy val hasSideEffect: Boolean = children.exists(_.hasSideEffect)

/**
* Returns true iff this expression is both deterministic and free of side effects, i.e.
* `deterministic && !hasSideEffect`. Declared final, so subclasses influence the result only
* through `deterministic` and `hasSideEffect`.
*/
final def idempotent: Boolean = deterministic && !hasSideEffect

def nullable: Boolean

def references: AttributeSet = AttributeSet.fromAttributeSets(children.map(_.references))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa

override def nullable: Boolean = true

override lazy val hasSideEffect: Boolean = true

override def inputTypes: Seq[DataType] = Seq(BooleanType)

override def dataType: DataType = NullType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
override def foldable: Boolean = false
override def nullable: Boolean = false

override lazy val hasSideEffect: Boolean = true

override def flatArguments: Iterator[Any] = Iterator(child)

private val errMsg = "Null value appeared in non-nullable field:" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
// Push down idempotent projection through UNION ALL
case p @ Project(projectList, Union(children)) =>
assert(children.nonEmpty)
if (projectList.forall(_.deterministic)) {
if (projectList.forall(_.idempotent)) {
val newFirstChild = Project(projectList, children.head)
val newOtherChildren = children.tail.map { child =>
val rewrites = buildRewrites(children.head, child)
Expand Down Expand Up @@ -649,13 +649,13 @@ object CollapseProject extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p1 @ Project(_, p2: Project) =>
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
if (haveCommonNonIdempotentOutput(p1.projectList, p2.projectList)) {
p1
} else {
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
}
case p @ Project(_, agg: Aggregate) =>
if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
if (haveCommonNonIdempotentOutput(p.projectList, agg.aggregateExpressions)) {
p
} else {
agg.copy(aggregateExpressions = buildCleanedProjectList(
Expand All @@ -669,7 +669,7 @@ object CollapseProject extends Rule[LogicalPlan] {
})
}

private def haveCommonNonDeterministicOutput(
private def haveCommonNonIdempotentOutput(
upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
// Create a map of Aliases to their values from the lower projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
Expand All @@ -679,7 +679,7 @@ object CollapseProject extends Rule[LogicalPlan] {
// deterministic.
upper.exists(_.collect {
case a: Attribute if aliases.contains(a) => aliases(a).child
}.exists(!_.deterministic))
}.exists(!_.idempotent))
}

private def buildCleanedProjectList(
Expand Down Expand Up @@ -755,8 +755,8 @@ object TransposeWindow extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
if w1.references.intersect(w2.windowOutputSet).isEmpty &&
w1.expressions.forall(_.deterministic) &&
w2.expressions.forall(_.deterministic) &&
w1.expressions.forall(_.idempotent) &&
w2.expressions.forall(_.idempotent) &&
compatibleParititions(ps1, ps2) =>
Project(w1.output, Window(we2, ps2, os2, Window(we1, ps1, os1, grandChild)))
}
Expand Down Expand Up @@ -831,7 +831,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
val newPredicates = constraints
.union(constructIsNotNullConstraints(constraints, plan.output))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.idempotent
} -- plan.constraints
if (newPredicates.isEmpty) {
plan
Expand Down Expand Up @@ -874,8 +874,8 @@ object CombineUnions extends Rule[LogicalPlan] {
object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// The query execution/optimization does not guarantee the expressions are evaluated in order.
// We only can combine them if and only if both are deterministic.
case Filter(fc, nf @ Filter(nc, grandChild)) if fc.deterministic && nc.deterministic =>
// We can combine them if and only if both are idempotent.
case Filter(fc, nf @ Filter(nc, grandChild)) if fc.idempotent && nc.idempotent =>
(ExpressionSet(splitConjunctivePredicates(fc)) --
ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match {
case Some(ac) =>
Expand Down Expand Up @@ -917,8 +917,8 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] {
}

def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
case p: Project => p.projectList.forall(_.deterministic)
case f: Filter => f.condition.deterministic
case p: Project => p.projectList.forall(_.idempotent)
case f: Filter => f.condition.idempotent
case _: ResolvedHint => true
case _ => false
}
Expand All @@ -940,12 +940,12 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming)
case Filter(Literal(false, BooleanType), child) =>
LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming)
// If any deterministic condition is guaranteed to be true given the constraints on the child's
// If any idempotent condition is guaranteed to be true given the constraints on the child's
// output, remove the condition
case f @ Filter(fc, p: LogicalPlan) =>
val (prunedPredicates, remainingPredicates) =
splitConjunctivePredicates(fc).partition { cond =>
cond.deterministic && p.constraints.contains(cond)
cond.idempotent && p.constraints.contains(cond)
}
if (prunedPredicates.isEmpty) {
f
Expand All @@ -968,13 +968,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// SPARK-13473: We can't push the predicate down when the underlying projection output non-
// deterministic field(s). Non-deterministic expressions are essentially stateful. This
// idempotent field(s). Non-idempotent expressions are essentially stateful. This
// implies that, for a given input row, the output is determined by the expression's initial
// state and all the input rows processed before. In other words, the order of input rows
// matters for non-deterministic expressions, while pushing down predicates changes the order.
// matters for non-idempotent expressions, while pushing down predicates changes the order.
// This also applies to Aggregate.
case Filter(condition, project @ Project(fields, grandChild))
if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
if fields.forall(_.idempotent) && canPushThroughCondition(grandChild, condition) =>

// Create a map of Aliases to their values from the child projection.
// e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
Expand All @@ -985,7 +985,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))

case filter @ Filter(condition, aggregate: Aggregate)
if aggregate.aggregateExpressions.forall(_.deterministic)
if aggregate.aggregateExpressions.forall(_.idempotent)
&& aggregate.groupingExpressions.nonEmpty =>
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression, and create a map from the alias to the expression
Expand All @@ -997,7 +997,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
// For each filter, expand the alias and check if the filter can be evaluated using
// attributes produced by the aggregate operator's child operator.
val (candidates, nonDeterministic) =
splitConjunctivePredicates(condition).partition(_.deterministic)
splitConjunctivePredicates(condition).partition(_.idempotent)

val (pushDown, rest) = candidates.partition { cond =>
val replaced = replaceAlias(cond, aliasMap)
Expand All @@ -1020,14 +1020,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
// Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be
// pushed beneath must satisfy the following conditions:
// 1. All the expressions are part of window partitioning key. The expressions can be compound.
// 2. Deterministic.
// 3. Placed before any non-deterministic predicates.
// 2. Idempotent.
// 3. Placed before any non-idempotent predicates.
case filter @ Filter(condition, w: Window)
if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references))

val (candidates, nonDeterministic) =
splitConjunctivePredicates(condition).partition(_.deterministic)
splitConjunctivePredicates(condition).partition(_.idempotent)

val (pushDown, rest) = candidates.partition { cond =>
cond.references.subsetOf(partitionAttrs)
Expand All @@ -1044,8 +1044,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
}

case filter @ Filter(condition, union: Union) =>
// Union could change the rows, so non-deterministic predicate can't be pushed down
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.deterministic)
// Union could change the rows, so non-idempotent predicates can't be pushed down
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.idempotent)

if (pushDown.nonEmpty) {
val pushDownCond = pushDown.reduceLeft(And)
Expand All @@ -1070,7 +1070,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {

case filter @ Filter(condition, watermark: EventTimeWatermark) =>
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { p =>
p.deterministic && !p.references.contains(watermark.eventTime)
p.idempotent && !p.references.contains(watermark.eventTime)
}

if (pushDown.nonEmpty) {
Expand All @@ -1084,7 +1084,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
}

case filter @ Filter(_, u: UnaryNode)
if canPushThrough(u) && u.expressions.forall(_.deterministic) =>
if canPushThrough(u) && u.expressions.forall(_.idempotent) =>
pushDownPredicate(filter, u.child) { predicate =>
u.withNewChildren(Seq(Filter(predicate, u.child)))
}
Expand All @@ -1108,18 +1108,18 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
private def pushDownPredicate(
filter: Filter,
grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
// Only push down the predicates that is deterministic and all the referenced attributes
// Only push down the predicates that are idempotent and whose referenced attributes all
// come from grandchild.
// TODO: non-deterministic predicates could be pushed through some operators that do not change
// TODO: non-idempotent predicates could be pushed through some operators that do not change
// the rows.
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (candidates, nonIdempotent) =
splitConjunctivePredicates(filter.condition).partition(_.idempotent)

val (pushDown, rest) = candidates.partition { cond =>
cond.references.subsetOf(grandchild.outputSet)
}

val stayUp = rest ++ nonDeterministic
val stayUp = rest ++ nonIdempotent

if (pushDown.nonEmpty) {
val newChild = insertFilter(pushDown.reduceLeft(And))
Expand Down Expand Up @@ -1162,13 +1162,13 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
/**
* Splits join condition expressions or filter predicates (on a given join's output) into three
* categories based on the attributes required to evaluate them. Note that we explicitly exclude
* non-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or
* non-idempotent (i.e., non-deterministic or side-effecting) condition expressions in
* canEvaluateInLeft or
* canEvaluateInRight to prevent pushing these predicates on either side of the join.
*
* @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
*/
private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic)
val (pushDownCandidates, nonDeterministic) = condition.partition(_.idempotent)
val (leftEvaluateCondition, rest) =
pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
val (rightEvaluateCondition, commonCondition) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
// grouping expressions.
val groupingExpressionSet = collectGroupingExpressions(q)
q transformExpressionsDown {
case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
case a: Add if a.idempotent && a.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
val foldableExpr = foldables.reduce((x, y) => Add(x, y))
Expand All @@ -193,7 +193,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
} else {
a
}
case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
case m: Multiply if m.idempotent && m.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
Expand Down Expand Up @@ -418,7 +418,7 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(FalseLiteral, _, falseValue) => falseValue
case If(Literal(null, _), _, falseValue) => falseValue
case If(cond, trueValue, falseValue)
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
if cond.idempotent && trueValue.semanticEquals(falseValue) => trueValue

case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
Expand Down Expand Up @@ -448,11 +448,11 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
if branches.forall(_._2.semanticEquals(elseValue)) =>
// Conditions that are not idempotent (non-deterministic or with side effects) cannot be removed
// or reordered. As a result, we only try to remove the idempotent conditions from the tail.
var hitNonDeterministicCond = false
var hitNonIdempotentCond = false
var i = branches.length
while (i > 0 && !hitNonDeterministicCond) {
hitNonDeterministicCond = !branches(i - 1)._1.deterministic
if (!hitNonDeterministicCond) {
while (i > 0 && !hitNonIdempotentCond) {
hitNonIdempotentCond = !branches(i - 1)._1.idempotent
if (!hitNonIdempotentCond) {
i -= 1
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
* Returns whether the expression returns null or false when all inputs are nulls.
*/
private def canFilterOutNull(e: Expression): Boolean = {
if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false
if (!e.idempotent || SubqueryExpression.hasCorrelatedSubquery(e)) return false
val attributes = e.references.toSeq
val emptyRow = new GenericInternalRow(attributes.length)
val boundE = BindReferences.bindReference(e, attributes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ object PhysicalOperation extends PredicateHelper {
private def collectProjectsAndFilters(plan: LogicalPlan):
(Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) =
plan match {
case Project(fields, child) if fields.forall(_.deterministic) =>
case Project(fields, child) if fields.forall(_.idempotent) =>
val (_, filters, other, aliases) = collectProjectsAndFilters(child)
val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
(Some(substitutedFields), filters, other, collectAliases(substitutedFields))

case Filter(condition, child) if condition.deterministic =>
case Filter(condition, child) if condition.idempotent =>
val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
val substitutedCondition = substitute(aliases)(condition)
(fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>
.union(inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints, output))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.idempotent
}
)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
Expand Down Expand Up @@ -166,4 +167,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
Literal(1))
)
}

test("SPARK-24913: don't skip AssertNotNull and AssertTrue") {
  // Each If has identical branches, but its condition is an assertion, so the optimizer is
  // expected to keep it rather than collapse the If into one of its branches.
  val assertNotNullCond = AssertNotNull(UnresolvedAttribute("b"))
  val assertTrueCond = AssertTrue(UnresolvedAttribute("b"))
  val filterCondition = And(
    If(assertNotNullCond, Literal(1), Literal(1)),
    If(assertTrueCond, Literal(1), Literal(1)))
  val originalPlan = Filter(filterCondition, OneRowRelation())
  val optimizedPlan = Optimize.execute(originalPlan).analyze
  // The optimized plan must be identical to the input plan.
  comparePlans(originalPlan, optimizedPlan, checkAnalysis = false)
}
}
Loading