From c2f9b058f6ad243e35518bc1758fe7ef7ac2c25f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Mar 2016 14:02:12 +0000 Subject: [PATCH 01/10] init import. --- .../sql/catalyst/planning/patterns.scala | 22 +++++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 8 +++++++ 2 files changed, 30 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 62d54df98ecc5..96eede403e9f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -202,3 +202,25 @@ object Unions { } } } + +/** + * A pattern that finds the original expression from a sequence of casts. + */ +object Casts { + def unapply(expr: Expression): Option[Attribute] = expr match { + case c: Cast => collectCasts(expr) + case _ => None + } + + private def collectCasts(e: Expression): Option[Attribute] = { + if (e.isInstanceOf[Cast]) { + collectCasts(e.children(0)) + } else { + if (e.isInstanceOf[Attribute]) { + Some(e.asInstanceOf[Attribute]) + } else { + None + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 920e989d058dc..71b9f762d49e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.Casts import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -36,6 +37,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT 
.union(constructIsNotNullConstraints(constraints)) .filter(constraint => constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + .map(_.transform { + case n @ IsNotNull(c) => + c match { + case Casts(a) if outputSet.contains(a) => IsNotNull(a) + case _ => n + } + }) } /** From b4e60339f713590fdb7294ab735eaefbe1940501 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 18 Mar 2016 02:41:16 +0000 Subject: [PATCH 02/10] Add test. --- .../sql/catalyst/plans/ConstraintPropagationSuite.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a9375a740daac..0478c9b9b7be7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.LongType class ConstraintPropagationSuite extends SparkFunSuite { @@ -217,4 +218,12 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b"))))) } + + test("infer constraints on cast") { + val tr = LocalRelation('a.int, 'b.long) + verifyConstraints(tr.where('a.attr === 'b.attr).analyze.constraints, + ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b"))))) + } } From d6bb43b4b6f6f7240973a8074024435183fd7d0a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 22 Mar 2016 01:04:49 +0000 Subject: [PATCH 03/10] Address comments. 
--- .../spark/sql/catalyst/planning/patterns.scala | 13 +++++-------- .../apache/spark/sql/catalyst/plans/QueryPlan.scala | 11 +++-------- .../catalyst/plans/ConstraintPropagationSuite.scala | 4 +++- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 96eede403e9f1..681f06ed1ecf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -212,15 +212,12 @@ object Casts { case _ => None } + @tailrec private def collectCasts(e: Expression): Option[Attribute] = { - if (e.isInstanceOf[Cast]) { - collectCasts(e.children(0)) - } else { - if (e.isInstanceOf[Attribute]) { - Some(e.asInstanceOf[Attribute]) - } else { - None - } + e match { + case e: Cast => collectCasts(e.child) + case e: Attribute => Some(e) + case _ => None } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 71b9f762d49e1..5b603f6339bb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -37,13 +37,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT .union(constructIsNotNullConstraints(constraints)) .filter(constraint => constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) - .map(_.transform { - case n @ IsNotNull(c) => - c match { - case Casts(a) if outputSet.contains(a) => IsNotNull(a) - case _ => n - } - }) } /** @@ -69,7 +62,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT Set(IsNotNull(l), IsNotNull(r)) case _ => Set.empty[Expression] - 
}.foldLeft(Set.empty[Expression])(_ union _.toSet) + }.foldLeft(Set.empty[Expression])(_ union _.toSet).map(_.transform { + case IsNotNull(Casts(a)) => IsNotNull(a) + }) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 0478c9b9b7be7..96015adb52fdc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -221,7 +221,9 @@ class ConstraintPropagationSuite extends SparkFunSuite { test("infer constraints on cast") { val tr = LocalRelation('a.int, 'b.long) - verifyConstraints(tr.where('a.attr === 'b.attr).analyze.constraints, + verifyConstraints( + tr.where('a.attr === 'b.attr && + IsNotNull(Cast(Cast(resolveColumn(tr, "a"), LongType), LongType))).analyze.constraints, ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b"))))) From 3d086259a2a447b322b69eec15274ea46160570b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 22 Mar 2016 06:47:02 +0000 Subject: [PATCH 04/10] Add nullIntolerant as a method to Expression. 
--- .../apache/spark/sql/catalyst/expressions/Expression.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 16a1b2aee2730..c836f43da8144 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -80,6 +80,12 @@ abstract class Expression extends TreeNode[Expression] { def nullable: Boolean + /** + * Indicates whether this expression is null intolerant + * (i.e., any null input will result in null output). + */ + def nullIntolerant: Boolean = false + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ From 5fa1760032628c9219d6c064526ee5a34b3de793 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 22 Mar 2016 10:43:29 +0000 Subject: [PATCH 05/10] Address comments. 
--- .../sql/catalyst/expressions/Expression.scala | 6 ++-- .../expressions/nullExpressions.scala | 7 +++++ .../sql/catalyst/expressions/predicates.scala | 2 ++ .../spark/sql/catalyst/plans/QueryPlan.scala | 29 +++++++------------ .../plans/ConstraintPropagationSuite.scala | 12 ++++++-- 5 files changed, 32 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c836f43da8144..b7b2b9a438dcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -81,10 +81,10 @@ abstract class Expression extends TreeNode[Expression] { def nullable: Boolean /** - * Indicates whether this expression is null intolerant - * (i.e., any null input will result in null output). + * Indicates whether this expression is null intolerant. If this is true, + * then any null input will result in null output). 
*/ - def nullIntolerant: Boolean = false + def nullIntolerant: Boolean = true def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index e22026d584654..4b577ede0173b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -96,6 +96,8 @@ case class IsNaN(child: Expression) extends UnaryExpression override def nullable: Boolean = false + override def nullIntolerant: Boolean = false + override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { @@ -183,6 +185,8 @@ case class NaNvl(left: Expression, right: Expression) case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false + override def nullIntolerant: Boolean = false + override def eval(input: InternalRow): Any = { child.eval(input) == null } @@ -204,6 +208,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false + override def nullIntolerant: Boolean = false + override def eval(input: InternalRow): Any = { child.eval(input) != null } @@ -224,6 +230,7 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { */ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false + override def nullIntolerant: Boolean = false override def foldable: Boolean = children.forall(_.foldable) override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 20818bfb1a514..59da16fc30c70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -408,6 +408,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def nullable: Boolean = false + override def nullIntolerant: Boolean = false + override def eval(input: InternalRow): Any = { val input1 = left.eval(input) val input2 = right.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 5b603f6339bb4..c14dfaf6a1c29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -47,24 +47,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { // Currently we only propagate constraints if the condition consists of equality // and ranges. 
For all other cases, we return an empty set of constraints - constraints.map { - case EqualTo(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case GreaterThan(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case GreaterThanOrEqual(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case LessThan(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case LessThanOrEqual(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case Not(EqualTo(l, r)) => - Set(IsNotNull(l), IsNotNull(r)) - case _ => - Set.empty[Expression] - }.foldLeft(Set.empty[Expression])(_ union _.toSet).map(_.transform { - case IsNotNull(Casts(a)) => IsNotNull(a) - }) + constraints.map(scanNullIntolerantExpr) + .foldLeft(Set.empty[Expression])(_ union _.toSet) + } + + private def scanNullIntolerantExpr(expr: Expression): Set[Expression] = expr match { + case a: Attribute => Set(IsNotNull(a)) + case IsNotNull(e) => + // IsNotNull is null tolerant, but we need to explore for the attributes not null. + scanNullIntolerantExpr(e) + case e: Expression if e.nullIntolerant => e.children.flatMap(scanNullIntolerantExpr).toSet + case _ => Set.empty[Expression] } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 96015adb52fdc..a280c111ca4c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -220,12 +220,18 @@ class ConstraintPropagationSuite extends SparkFunSuite { } test("infer constraints on cast") { - val tr = LocalRelation('a.int, 'b.long) + val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) verifyConstraints( tr.where('a.attr === 'b.attr && - IsNotNull(Cast(Cast(resolveColumn(tr, "a"), LongType), LongType))).analyze.constraints, + 'c.attr + 100 > 'd.attr && + 
IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints, ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"), + Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"), IsNotNull(resolveColumn(tr, "a")), - IsNotNull(resolveColumn(tr, "b"))))) + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e")), + IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))))) } } From 3373396692ee8c3f1ac7694c36da46da8437f54d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 22 Mar 2016 13:45:13 +0000 Subject: [PATCH 06/10] Fix test. --- .../org/apache/spark/sql/hive/orc/OrcFilterSuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index 7b0c7a9f00514..8c52ebd532165 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -226,9 +226,10 @@ class OrcFilterSuite extends QueryTest with OrcTest { ) checkFilterPredicate( '_1 < 2 || '_1 > 3, - """leaf-0 = (LESS_THAN _1 2) - |leaf-1 = (LESS_THAN_EQUALS _1 3) - |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim + """leaf-0 = (IS_NULL _1) + |leaf-1 = (LESS_THAN _1 2) + |leaf-2 = (LESS_THAN_EQUALS _1 3) + |expr = (and (not leaf-0) (or leaf-1 (not leaf-2)))""".stripMargin.trim ) checkFilterPredicate( '_1 < 2 && '_1 > 3, From 56ca15fa348d0488ca689f9fec2dd912d0625fc4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 Mar 2016 03:48:29 +0000 Subject: [PATCH 07/10] Use trait for null intolerant expression. 
--- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 6 ----- .../sql/catalyst/expressions/arithmetic.scala | 25 ++++++++++++------- .../expressions/nullExpressions.scala | 7 ------ .../sql/catalyst/expressions/package.scala | 7 ++++++ .../sql/catalyst/expressions/predicates.scala | 19 ++++++++------ .../spark/sql/catalyst/plans/QueryPlan.scala | 18 ++++++------- 7 files changed, 44 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a965cc8d5332b..d842ffdc6637c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -112,7 +112,7 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant { override def toString: String = s"cast($child as ${dataType.simpleString})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6875915f79b15..5f8899d5998a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -80,12 +80,6 @@ abstract class Expression extends TreeNode[Expression] { def nullable: Boolean - /** - * Indicates whether this expression is null intolerant. If this is true, - * then any null input will result in null output). 
- */ - def nullIntolerant: Boolean = true - def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ed812e06799a9..a475718fa6205 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnaryMinus(child: Expression) extends UnaryExpression + with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def sql: String = s"(-${child.sql})" } -case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnaryPositive(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def prettyName: String = "positive" override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects @ExpressionDescription( usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", extended = "> SELECT _FUNC_('-1');\n1") -case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Abs(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = 
Seq(NumericType) @@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } -case class Add(left: Expression, right: Expression) extends BinaryArithmetic { +case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } } -case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { +case class Subtract(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } } -case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { +case class Multiply(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { +case class Divide(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -255,7 +261,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -429,7 +436,7 @@ case class MinOf(left: 
Expression, right: Expression) override def symbol: String = "min" } -case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 4b577ede0173b..e22026d584654 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -96,8 +96,6 @@ case class IsNaN(child: Expression) extends UnaryExpression override def nullable: Boolean = false - override def nullIntolerant: Boolean = false - override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { @@ -185,8 +183,6 @@ case class NaNvl(left: Expression, right: Expression) case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false - override def nullIntolerant: Boolean = false - override def eval(input: InternalRow): Any = { child.eval(input) == null } @@ -208,8 +204,6 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false - override def nullIntolerant: Boolean = false - override def eval(input: InternalRow): Any = { child.eval(input) != null } @@ -230,7 +224,6 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { */ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false - override def nullIntolerant: Boolean = false override def foldable: Boolean = children.forall(_.foldable) 
override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index f1fa13daa77eb..23baa6f7837fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -92,4 +92,11 @@ package object expressions { StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) } } + + /** + * An expression that inherits this trait is null intolerant (i.e. any null + * input will result in null output). This information is used when constructing IsNotNull + * constraints. + */ + trait NullIntolerant } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 59da16fc30c70..dbf8b3d2bace4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -90,7 +90,7 @@ trait PredicateHelper { case class Not(child: Expression) - extends UnaryExpression with Predicate with ImplicitCastInputTypes { + extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant { override def toString: String = s"NOT $child" @@ -376,7 +376,8 @@ private[sql] object Equality { } -case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { +case class EqualTo(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = AnyDataType @@ -408,8 +409,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def nullable: Boolean = false - 
override def nullIntolerant: Boolean = false - override def eval(input: InternalRow): Any = { val input1 = left.eval(input) val input2 = right.eval(input) @@ -443,7 +442,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } -case class LessThan(left: Expression, right: Expression) extends BinaryComparison { +case class LessThan(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -455,7 +455,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso } -case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { +case class LessThanOrEqual(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -467,7 +468,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo } -case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { +case class GreaterThan(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -479,7 +481,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar } -case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { +case class GreaterThanOrEqual(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index e17c886123288..f42b36aac59fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -46,17 +46,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { // Currently we only propagate constraints if the condition consists of equality // and ranges. For all other cases, we return an empty set of constraints - constraints.map(scanNullIntolerantExpr) - .foldLeft(Set.empty[Expression])(_ union _.toSet) + constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_)) } - private def scanNullIntolerantExpr(expr: Expression): Set[Expression] = expr match { - case a: Attribute => Set(IsNotNull(a)) - case IsNotNull(e) => - // IsNotNull is null tolerant, but we need to explore for the attributes not null. - scanNullIntolerantExpr(e) - case e: Expression if e.nullIntolerant => e.children.flatMap(scanNullIntolerantExpr).toSet - case _ => Set.empty[Expression] + /** + * Recursively explores the expressions which are null intolerant and returns all attributes + * in these expressions. + */ + private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match { + case a: Attribute => Seq(a) + case _: NullIntolerant | _: IsNotNull => expr.children.flatMap(scanNullIntolerantExpr) + case _ => Seq.empty[Attribute] } /** From c8fb736cf9e1d8b3626e7ab2323da4db9135ff9b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 Mar 2016 06:24:00 +0000 Subject: [PATCH 08/10] Or is not null intolerant. 
--- .../org/apache/spark/sql/hive/orc/OrcFilterSuite.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index 8c52ebd532165..7b0c7a9f00514 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -226,10 +226,9 @@ class OrcFilterSuite extends QueryTest with OrcTest { ) checkFilterPredicate( '_1 < 2 || '_1 > 3, - """leaf-0 = (IS_NULL _1) - |leaf-1 = (LESS_THAN _1 2) - |leaf-2 = (LESS_THAN_EQUALS _1 3) - |expr = (and (not leaf-0) (or leaf-1 (not leaf-2)))""".stripMargin.trim + """leaf-0 = (LESS_THAN _1 2) + |leaf-1 = (LESS_THAN_EQUALS _1 3) + |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim ) checkFilterPredicate( '_1 < 2 && '_1 > 3, From 81c46c72117f30679f8d11c908340dc9067a14e7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 Mar 2016 07:48:45 +0000 Subject: [PATCH 09/10] Add more tests. 
--- .../plans/ConstraintPropagationSuite.scala | 59 ++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a280c111ca4c1..f39d01668bead 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{DoubleType, LongType} class ConstraintPropagationSuite extends SparkFunSuite { @@ -234,4 +234,61 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "e")), IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))))) } + + test("infer isnotnull constraints from compound expressions") { + val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) + verifyConstraints( + tr.where('a.attr + 'b.attr === 'c.attr && + IsNotNull( + Cast( + Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") === + Cast(resolveColumn(tr, "c"), LongType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "e")), + IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))))) + + verifyConstraints( + tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) 
* resolveColumn(tr, "b") + Cast(100, LongType) === + Cast(resolveColumn(tr, "c"), LongType), + Cast(resolveColumn(tr, "d"), DoubleType) / + Cast(Cast(10, LongType), DoubleType) === + Cast(resolveColumn(tr, "e"), DoubleType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + verifyConstraints( + tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >= + Cast(resolveColumn(tr, "c"), LongType), + Cast(resolveColumn(tr, "d"), DoubleType) / + Cast(Cast(10, LongType), DoubleType) < + Cast(resolveColumn(tr, "e"), DoubleType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + verifyConstraints( + tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints, + ExpressionSet(Seq( + (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) - + (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) > + Cast(resolveColumn(tr, "e") * 1000, LongType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + } } From 8e8dd72090859c65fc23b5a85745b85df54473dc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 24 Mar 2016 06:23:37 +0000 Subject: [PATCH 10/10] Consider IsNotNull(IsNotNull(expr)) case. Add a test for it too. 
--- .../sql/catalyst/expressions/namedExpressions.scala | 2 +- .../apache/spark/sql/catalyst/plans/QueryPlan.scala | 3 ++- .../catalyst/plans/ConstraintPropagationSuite.scala | 10 ++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index a5b5758167276..262582ca5d9c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -97,7 +97,7 @@ trait NamedExpression extends Expression { } } -abstract class Attribute extends LeafExpression with NamedExpression { +abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant { override def references: AttributeSet = AttributeSet(this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index f42b36aac59fd..d9be7ab055f6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -55,7 +55,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match { case a: Attribute => Seq(a) - case _: NullIntolerant | _: IsNotNull => expr.children.flatMap(scanNullIntolerantExpr) + case _: NullIntolerant | IsNotNull(_: NullIntolerant) => + expr.children.flatMap(scanNullIntolerantExpr) case _ => Seq.empty[Attribute] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index f39d01668bead..f25ede6d1a5ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -290,5 +290,15 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "c")), IsNotNull(resolveColumn(tr, "d")), IsNotNull(resolveColumn(tr, "e"))))) + + // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null. + verifyConstraints( + tr.where('a.attr === 'c.attr && + IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints, + ExpressionSet(Seq( + resolveColumn(tr, "a") === resolveColumn(tr, "c"), + IsNotNull(IsNotNull(resolveColumn(tr, "b"))), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c"))))) } }