diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 4cff162c116a4..ae3b75dc3334b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.catalog
 
 import java.net.URI
-import java.util.Locale
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.util.Shell
@@ -26,7 +25,7 @@ import org.apache.hadoop.util.Shell
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, Predicate}
 
 object ExternalCatalogUtils {
   // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't
@@ -148,7 +147,7 @@ object ExternalCatalogUtils {
     }
 
     val boundPredicate =
-      InterpretedPredicate.create(predicates.reduce(And).transform {
+      Predicate.createInterpreted(predicates.reduce(And).transform {
         case att: AttributeReference =>
           val index = partitionSchema.indexWhere(_.name == att.name)
           BoundReference(index, partitionSchema(index).dataType, nullable = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 300f075d32763..b4a85e3e50bec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -89,14 +89,14 @@ object MutableProjection
   }
 
   /**
-   * Returns an MutableProjection for given sequence of bound Expressions.
+   * Returns a MutableProjection for given sequence of bound Expressions.
    */
   def create(exprs: Seq[Expression]): MutableProjection = {
     createObject(exprs)
   }
 
   /**
-   * Returns an MutableProjection for given sequence of Expressions, which will be bound to
+   * Returns a MutableProjection for given sequence of Expressions, which will be bound to
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index e0fabad6d089a..6ba646d360d2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -20,31 +20,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 
-/**
- * Interface for generated predicate
- */
-abstract class Predicate {
-  def eval(r: InternalRow): Boolean
-
-  /**
-   * Initializes internal states given the current partition index.
-   * This is used by nondeterministic expressions to set initial states.
-   * The default implementation does nothing.
-   */
-  def initialize(partitionIndex: Int): Unit = {}
-}
-
 /**
  * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]].
  */
-object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
+object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {
 
   protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)
 
   protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
     BindReferences.bindReference(in, inputSchema)
 
-  protected def create(predicate: Expression): Predicate = {
+  protected def create(predicate: Expression): BasePredicate = {
     val ctx = newCodeGenContext()
     val eval = predicate.genCode(ctx)
 
@@ -53,7 +39,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
           return new SpecificPredicate(references);
         }
 
-      class SpecificPredicate extends ${classOf[Predicate].getName} {
+      class SpecificPredicate extends ${classOf[BasePredicate].getName} {
         private final Object[] references;
         ${ctx.declareMutableStates()}
 
@@ -79,6 +65,6 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
     logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")
 
     val (clazz, _) = CodeGenerator.compile(code)
-    clazz.generate(ctx.references.toArray).asInstanceOf[Predicate]
+    clazz.generate(ctx.references.toArray).asInstanceOf[BasePredicate]
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 4c0998412f729..bcd442ad3cc35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -21,8 +21,9 @@ import scala.collection.immutable.TreeSet
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -30,11 +31,18 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 
-object InterpretedPredicate {
-  def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate =
-    create(BindReferences.bindReference(expression, inputSchema))
+/**
+ * A base class for generated/interpreted predicate
+ */
+abstract class BasePredicate {
+  def eval(r: InternalRow): Boolean
 
-  def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression)
+  /**
+   * Initializes internal states given the current partition index.
+   * This is used by nondeterministic expressions to set initial states.
+   * The default implementation does nothing.
+   */
+  def initialize(partitionIndex: Int): Unit = {}
 }
 
 case class InterpretedPredicate(expression: Expression) extends BasePredicate {
@@ -56,6 +64,35 @@ trait Predicate extends Expression {
   override def dataType: DataType = BooleanType
 }
 
+/**
+ * The factory object for `BasePredicate`.
+ */
+object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePredicate] {
+
+  override protected def createCodeGeneratedObject(in: Expression): BasePredicate = {
+    GeneratePredicate.generate(in)
+  }
+
+  override protected def createInterpretedObject(in: Expression): BasePredicate = {
+    InterpretedPredicate(in)
+  }
+
+  def createInterpreted(e: Expression): InterpretedPredicate = InterpretedPredicate(e)
+
+  /**
+   * Returns a BasePredicate for an Expression, which will be bound to `inputSchema`.
+   */
+  def create(e: Expression, inputSchema: Seq[Attribute]): BasePredicate = {
+    createObject(bindReference(e, inputSchema))
+  }
+
+  /**
+   * Returns a BasePredicate for a given bound Expression.
+   */
+  def create(e: Expression): BasePredicate = {
+    createObject(e)
+  }
+}
 
 trait PredicateHelper {
   protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
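The two hunks above are the core of the change: the old `Predicate` abstract class moves out of the codegen package (renamed to `BasePredicate` in `predicates.scala`), and the new `Predicate` factory object extends `CodeGeneratorWithInterpretedFallback`, so callers get a code-generated predicate that automatically falls back to `InterpretedPredicate` when compilation fails. A minimal usage sketch, not part of the patch (the column name and rows are made up), e.g. in a Catalyst test or the REPL:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal, Predicate}
import org.apache.spark.sql.types.IntegerType

// One INT column named "a"; the same AttributeReference instance is reused so the
// factory can bind the expression against the schema.
val a = AttributeReference("a", IntegerType)()
val schema = Seq(a)

// a > 5: code generation by default, interpreted evaluation if compilation fails.
val predicate = Predicate.create(GreaterThan(a, Literal(5)), schema)
predicate.initialize(0) // seeds nondeterministic expressions for partition 0

Seq(InternalRow(3), InternalRow(7)).map(predicate.eval) // Seq(false, true)
```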
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b78bdf082f333..f855596c363be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1507,7 +1507,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
 
     case Filter(condition, LocalRelation(output, data, isStreaming))
         if !hasUnevaluableExpr(condition) =>
-      val predicate = InterpretedPredicate.create(condition, output)
+      val predicate = Predicate.create(condition, output)
       predicate.initialize(0)
       LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 52cdd988caa2e..67a41e7cc2767 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -510,7 +510,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("Interpreted Predicate should initialize nondeterministic expressions") {
-    val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0)))
+    val interpreted = Predicate.create(LessThan(Rand(7), Literal(1.0)))
     interpreted.initialize(0)
     assert(interpreted.eval(new UnsafeRow()))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index b0fe4b741479f..88f5673aa9a1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -230,7 +230,7 @@ case class FileSourceScanExec(
       // call the file index for the files matching all filters except dynamic partition filters
       val predicate = dynamicPartitionFilters.reduce(And)
       val partitionColumns = relation.partitionSchema
-      val boundPredicate = newPredicate(predicate.transform {
+      val boundPredicate = Predicate.create(predicate.transform {
         case a: AttributeReference =>
           val index = partitionColumns.indexWhere(a.name == _.name)
           BoundReference(index, partitionColumns(index).dataType, nullable = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 125f76282e3df..738af995376e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -21,7 +21,6 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
 
 import org.codehaus.commons.compiler.CompileException
 import org.codehaus.janino.InternalCompilerException
@@ -33,7 +32,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -471,28 +470,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     MutableProjection.create(expressions, inputSchema)
   }
 
-  private def genInterpretedPredicate(
-      expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = {
-    val str = expression.toString
-    val logMessage = if (str.length > 256) {
-      str.substring(0, 256 - 3) + "..."
-    } else {
-      str
-    }
-    logWarning(s"Codegen disabled for this expression:\n $logMessage")
-    InterpretedPredicate.create(expression, inputSchema)
-  }
-
-  protected def newPredicate(
-      expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
-    try {
-      GeneratePredicate.generate(expression, inputSchema)
-    } catch {
-      case _ @ (_: InternalCompilerException | _: CompileException) if codeGenFallBack =>
-        genInterpretedPredicate(expression, inputSchema)
-    }
-  }
-
   protected def newOrdering(
       order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = {
     GenerateOrdering.generate(order, inputSchema)
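With `newPredicate` and its hand-rolled `genInterpretedPredicate` fallback removed from `SparkPlan`, the Janino `CompileException`/`InternalCompilerException` handling now lives behind `Predicate.create` (via `CodeGeneratorWithInterpretedFallback`). As a sketch of the per-partition pattern call sites such as `FilterExec` follow after this change (`filterPartition` is a hypothetical helper, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Predicate}

// Hypothetical helper mirroring what FilterExec does inside mapPartitionsWithIndexInternal.
def filterPartition(
    condition: Expression,
    output: Seq[Attribute],
    partitionIndex: Int,
    rows: Iterator[InternalRow]): Iterator[InternalRow] = {
  // Codegen first; the factory falls back to InterpretedPredicate if compilation fails.
  val predicate = Predicate.create(condition, output)
  predicate.initialize(partitionIndex)
  rows.filter(predicate.eval)
}
```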
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 3ed42f359c0a4..e128d59dca6ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.{LongType, StructType}
@@ -227,7 +226,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
-      val predicate = newPredicate(condition, child.output)
+      val predicate = Predicate.create(condition, child.output)
       predicate.initialize(0)
       iter.filter { row =>
         val r = predicate.eval(row)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 8d13cfb93d270..f03c2586048bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -310,7 +310,7 @@ case class InMemoryTableScanExec(
     val buffers = relation.cacheBuilder.cachedColumnBuffers
 
     buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
-      val partitionFilter = newPredicate(
+      val partitionFilter = Predicate.create(
         partitionFilters.reduceOption(And).getOrElse(Literal(true)),
         schema)
       partitionFilter.initialize(index)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 3adec2f790730..21ddeb6491155 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -171,7 +171,7 @@ abstract class PartitioningAwareFileIndex(
     if (partitionPruningPredicates.nonEmpty) {
       val predicate = partitionPruningPredicates.reduce(expressions.And)
 
-      val boundPredicate = InterpretedPredicate.create(predicate.transform {
+      val boundPredicate = Predicate.createInterpreted(predicate.transform {
         case a: AttributeReference =>
           val index = partitionColumns.indexWhere(a.name == _.name)
           BoundReference(index, partitionColumns(index).dataType, nullable = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index f526a19876670..5517c0dcdb188 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -19,14 +19,12 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.collection.{BitSet, CompactBuffer}
 
 case class BroadcastNestedLoopJoinExec(
@@ -84,7 +82,7 @@ case class BroadcastNestedLoopJoinExec(
 
   @transient private lazy val boundCondition = {
     if (condition.isDefined) {
-      newPredicate(condition.get, streamed.output ++ broadcast.output).eval _
+      Predicate.create(condition.get, streamed.output ++ broadcast.output).eval _
     } else {
       (r: InternalRow) => true
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 88d98530991c9..29645a736548c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark._
 import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Predicate, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.CompletionIterator
@@ -93,7 +92,7 @@ case class CartesianProductExec(
     pair.mapPartitionsWithIndexInternal { (index, iter) =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {
-        val boundCondition = newPredicate(condition.get, left.output ++ right.output)
+        val boundCondition = Predicate.create(condition.get, left.output ++ right.output)
         boundCondition.initialize(index)
         val joined = new JoinedRow
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index e8938cb22e890..137f0b87a2f3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -99,7 +99,7 @@ trait HashJoin {
     UnsafeProjection.create(streamedKeys)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
-    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _
+    Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
   } else {
     (r: InternalRow) => true
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 26fb0e5ffb1af..cd3c596435a21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -168,7 +168,7 @@ case class SortMergeJoinExec(
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
-          newPredicate(cond, left.output ++ right.output).eval _
+          Predicate.create(cond, left.output ++ right.output).eval _
         }.getOrElse {
           (r: InternalRow) => true
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 6bb4dc1672900..f1bfe97610fed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -233,8 +233,9 @@ case class StreamingSymmetricHashJoinExec(
 
     val joinedRow = new JoinedRow
 
+    val inputSchema = left.output ++ right.output
     val postJoinFilter =
-      newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _
+      Predicate.create(condition.bothSides.getOrElse(Literal(true)), inputSchema).eval _
     val leftSideJoiner = new OneSideHashJoiner(
       LeftSide, left.output, leftKeys, leftInputIter, condition.leftSideOnly,
       postJoinFilter, stateWatermarkPredicates.left)
@@ -417,7 +418,7 @@ case class StreamingSymmetricHashJoinExec(
 
     // Filter the joined rows based on the given condition.
     val preJoinFilter =
-      newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _
+      Predicate.create(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _
 
     private val joinStateManager = new SymmetricHashJoinStateManager(
       joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value,
@@ -428,16 +429,16 @@ case class StreamingSymmetricHashJoinExec(
       case Some(JoinStateKeyWatermarkPredicate(expr)) =>
         // inputSchema can be empty as expr should only have BoundReferences and does not require
         // the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]].
-        newPredicate(expr, Seq.empty).eval _
+        Predicate.create(expr, Seq.empty).eval _
       case _ =>
-        newPredicate(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
+        Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
     }
 
     private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
       case Some(JoinStateValueWatermarkPredicate(expr)) =>
-        newPredicate(expr, inputAttributes).eval _
+        Predicate.create(expr, inputAttributes).eval _
       case _ =>
-        newPredicate(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
+        Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
     }
 
     private[this] var updatedStateRowsCount = 0
@@ -457,7 +458,7 @@ case class StreamingSymmetricHashJoinExec(
     val nonLateRows =
       WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match {
         case Some(watermarkExpr) =>
-          val predicate = newPredicate(watermarkExpr, inputAttributes)
+          val predicate = Predicate.create(watermarkExpr, inputAttributes)
           inputIter.filter { row => !predicate.eval(row) }
         case None =>
           inputIter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index d689a6f3c9819..01b309c3cf345 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
@@ -156,17 +156,17 @@ trait WatermarkSupport extends UnaryExecNode {
   }
 
   /** Predicate based on keys that matches data older than the watermark */
-  lazy val watermarkPredicateForKeys: Option[Predicate] = watermarkExpression.flatMap { e =>
+  lazy val watermarkPredicateForKeys: Option[BasePredicate] = watermarkExpression.flatMap { e =>
     if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) {
-      Some(newPredicate(e, keyExpressions))
+      Some(Predicate.create(e, keyExpressions))
     } else {
       None
     }
   }
 
   /** Predicate based on the child output that matches data older than the watermark. */
-  lazy val watermarkPredicateForData: Option[Predicate] =
-    watermarkExpression.map(newPredicate(_, child.output))
+  lazy val watermarkPredicateForData: Option[BasePredicate] =
+    watermarkExpression.map(Predicate.create(_, child.output))
 
   protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
     if (watermarkPredicateForKeys.nonEmpty) {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 60a4638f610b3..d1b97b2852fbc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
 import org.apache.spark.sql.{sources, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedProjection, JoinedRow, Literal, Predicate}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -88,7 +88,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
         val attribute = inputAttributes.find(_.name == column).get
         expressions.GreaterThan(attribute, literal)
       }.reduceOption(expressions.And).getOrElse(Literal(true))
-      InterpretedPredicate.create(filterCondition, inputAttributes)
+      Predicate.create(filterCondition, inputAttributes)
     }
 
     // Uses a simple projection to simulate column pruning
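Taken together, the factory keeps two entry points: `Predicate.create`, which goes through codegen with interpreted fallback, and `Predicate.createInterpreted`, which call sites like `ExternalCatalogUtils` and `PartitioningAwareFileIndex` use to stay on the interpreted path. A rough end-to-end sketch (REPL/test style, not part of the patch) mirroring the updated `PredicateSuite` test:

```scala
import org.apache.spark.sql.catalyst.expressions.{LessThan, Literal, Predicate, Rand, UnsafeRow}

// Rand is nondeterministic, so both predicates must be initialized with a partition
// index before eval; rand() < 1.0 is always true.
val expr = LessThan(Rand(7), Literal(1.0))

val interpreted = Predicate.createInterpreted(expr) // always InterpretedPredicate
interpreted.initialize(0)
assert(interpreted.eval(new UnsafeRow()))

val generated = Predicate.create(expr) // codegen, interpreted only on compile failure
generated.initialize(0)
assert(generated.eval(new UnsafeRow()))
```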