From 4b1e3119e998e68ef3f58a82fb20440e46dbd8a0 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 20 Nov 2019 11:06:51 +0900 Subject: [PATCH 1/7] Fix --- .../sql/catalyst/expressions/Projection.scala | 4 +- .../codegen/GeneratePredicate.scala | 22 ++--------- .../sql/catalyst/expressions/predicates.scala | 37 ++++++++++++++++++- .../sql/execution/DataSourceScanExec.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 25 +------------ .../execution/basicPhysicalOperators.scala | 3 +- .../columnar/InMemoryTableScanExec.scala | 2 +- .../joins/BroadcastNestedLoopJoinExec.scala | 4 +- .../joins/CartesianProductExec.scala | 5 +-- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 17 +++++---- .../streaming/statefulOperators.scala | 10 ++--- 13 files changed, 65 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 300f075d32763..b4a85e3e50bec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -89,14 +89,14 @@ object MutableProjection } /** - * Returns an MutableProjection for given sequence of bound Expressions. + * Returns a MutableProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): MutableProjection = { createObject(exprs) } /** - * Returns an MutableProjection for given sequence of Expressions, which will be bound to + * Returns a MutableProjection for given sequence of Expressions, which will be bound to * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index e0fabad6d089a..6ba646d360d2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -20,31 +20,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -/** - * Interface for generated predicate - */ -abstract class Predicate { - def eval(r: InternalRow): Boolean - - /** - * Initializes internal states given the current partition index. - * This is used by nondeterministic expressions to set initial states. - * The default implementation does nothing. - */ - def initialize(partitionIndex: Int): Unit = {} -} - /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]]. */ -object GeneratePredicate extends CodeGenerator[Expression, Predicate] { +object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) - protected def create(predicate: Expression): Predicate = { + protected def create(predicate: Expression): BasePredicate = { val ctx = newCodeGenContext() val eval = predicate.genCode(ctx) @@ -53,7 +39,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { return new SpecificPredicate(references); } - class SpecificPredicate extends ${classOf[Predicate].getName} { + class SpecificPredicate extends ${classOf[BasePredicate].getName} { private final Object[] references; ${ctx.declareMutableStates()} @@ -79,6 +65,6 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") val (clazz, _) = CodeGenerator.compile(code) - clazz.generate(ctx.references.toArray).asInstanceOf[Predicate] + clazz.generate(ctx.references.toArray).asInstanceOf[BasePredicate] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4c0998412f729..b7844698c83ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,8 +21,9 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.TypeUtils @@ -30,6 +31,20 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +/** + * Interface for generated/interpreted predicate + */ +abstract class BasePredicate { + def eval(r: InternalRow): Boolean + + /** + * Initializes internal states given the current partition index. + * This is used by nondeterministic expressions to set initial states. + * The default implementation does nothing. + */ + def initialize(partitionIndex: Int): Unit = {} +} + object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = create(BindReferences.bindReference(expression, inputSchema)) @@ -56,6 +71,26 @@ trait Predicate extends Expression { override def dataType: DataType = BooleanType } +/** + * The factory object for `BasePredicate`. + */ +object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePredicate] { + + override protected def createCodeGeneratedObject(in: Expression): BasePredicate = { + GeneratePredicate.generate(in) + } + + override protected def createInterpretedObject(in: Expression): BasePredicate = { + InterpretedPredicate.create(in) + } + + /** + * Returns a BasePredicate for an Expression, which will be bound to `inputSchema`. + */ + def create(exprs: Expression, inputSchema: Seq[Attribute]): BasePredicate = { + createObject(bindReference(exprs, inputSchema)) + } +} trait PredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index b0fe4b741479f..88f5673aa9a1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -230,7 +230,7 @@ case class FileSourceScanExec( // call the file index for the files matching all filters except dynamic partition filters val predicate = dynamicPartitionFilters.reduce(And) val partitionColumns = relation.partitionSchema - val boundPredicate = newPredicate(predicate.transform { + val boundPredicate = Predicate.create(predicate.transform { case a: AttributeReference => val index = partitionColumns.indexWhere(a.name == _.name) BoundReference(index, partitionColumns(index).dataType, nullable = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 125f76282e3df..738af995376e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -21,7 +21,6 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.InternalCompilerException @@ -33,7 +32,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ @@ -471,28 +470,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ MutableProjection.create(expressions, inputSchema) } - private def genInterpretedPredicate( - expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = { - val str = expression.toString - val logMessage = if (str.length > 256) { - str.substring(0, 256 - 3) + "..." - } else { - str - } - logWarning(s"Codegen disabled for this expression:\n $logMessage") - InterpretedPredicate.create(expression, inputSchema) - } - - protected def newPredicate( - expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case _ @ (_: InternalCompilerException | _: CompileException) if codeGenFallBack => - genInterpretedPredicate(expression, inputSchema) - } - } - protected def newOrdering( order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { GenerateOrdering.generate(order, inputSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 3ed42f359c0a4..e128d59dca6ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{LongType, StructType} @@ -227,7 +226,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithIndexInternal { (index, iter) => - val predicate = newPredicate(condition, child.output) + val predicate = Predicate.create(condition, child.output) predicate.initialize(0) iter.filter { row => val r = predicate.eval(row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 8d13cfb93d270..f03c2586048bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -310,7 +310,7 @@ case class InMemoryTableScanExec( val buffers = relation.cacheBuilder.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => - val partitionFilter = newPredicate( + val partitionFilter = Predicate.create( partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema) partitionFilter.initialize(index) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index f526a19876670..5517c0dcdb188 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.collection.{BitSet, CompactBuffer} case class BroadcastNestedLoopJoinExec( @@ -84,7 +82,7 @@ case class BroadcastNestedLoopJoinExec( @transient private lazy val boundCondition = { if (condition.isDefined) { - newPredicate(condition.get, streamed.output ++ broadcast.output).eval _ + Predicate.create(condition.get, streamed.output ++ broadcast.output).eval _ } else { (r: InternalRow) => true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 88d98530991c9..29645a736548c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark._ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Predicate, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils, ExternalAppendOnlyUnsafeRowArray, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.CompletionIterator @@ -93,7 +92,7 @@ case class CartesianProductExec( pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { - val boundCondition = newPredicate(condition.get, left.output ++ right.output) + val boundCondition = Predicate.create(condition.get, left.output ++ right.output) boundCondition.initialize(index) val joined = new JoinedRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index e8938cb22e890..137f0b87a2f3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -99,7 +99,7 @@ trait HashJoin { UnsafeProjection.create(streamedKeys) @transient private[this] lazy val boundCondition = if (condition.isDefined) { - newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _ + Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _ } else { (r: InternalRow) => true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 26fb0e5ffb1af..cd3c596435a21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -168,7 +168,7 @@ case class SortMergeJoinExec( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => - newPredicate(cond, left.output ++ right.output).eval _ + Predicate.create(cond, left.output ++ right.output).eval _ }.getOrElse { (r: InternalRow) => true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 6bb4dc1672900..f1bfe97610fed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -233,8 +233,9 @@ case class StreamingSymmetricHashJoinExec( val joinedRow = new JoinedRow + val inputSchema = left.output ++ right.output val postJoinFilter = - newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _ + Predicate.create(condition.bothSides.getOrElse(Literal(true)), inputSchema).eval _ val leftSideJoiner = new OneSideHashJoiner( LeftSide, left.output, leftKeys, leftInputIter, condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left) @@ -417,7 +418,7 @@ case class StreamingSymmetricHashJoinExec( // Filter the joined rows based on the given condition. val preJoinFilter = - newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ + Predicate.create(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ private val joinStateManager = new SymmetricHashJoinStateManager( joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value, @@ -428,16 +429,16 @@ case class StreamingSymmetricHashJoinExec( case Some(JoinStateKeyWatermarkPredicate(expr)) => // inputSchema can be empty as expr should only have BoundReferences and does not require // the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]]. - newPredicate(expr, Seq.empty).eval _ + Predicate.create(expr, Seq.empty).eval _ case _ => - newPredicate(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate + Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate } private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match { case Some(JoinStateValueWatermarkPredicate(expr)) => - newPredicate(expr, inputAttributes).eval _ + Predicate.create(expr, inputAttributes).eval _ case _ => - newPredicate(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate + Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate } private[this] var updatedStateRowsCount = 0 @@ -457,7 +458,7 @@ case class StreamingSymmetricHashJoinExec( val nonLateRows = WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match { case Some(watermarkExpr) => - val predicate = newPredicate(watermarkExpr, inputAttributes) + val predicate = Predicate.create(watermarkExpr, inputAttributes) inputIter.filter { row => !predicate.eval(row) } case None => inputIter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d689a6f3c9819..01b309c3cf345 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ @@ -156,17 +156,17 @@ trait WatermarkSupport extends UnaryExecNode { } /** Predicate based on keys that matches data older than the watermark */ - lazy val watermarkPredicateForKeys: Option[Predicate] = watermarkExpression.flatMap { e => + lazy val watermarkPredicateForKeys: Option[BasePredicate] = watermarkExpression.flatMap { e => if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) { - Some(newPredicate(e, keyExpressions)) + Some(Predicate.create(e, keyExpressions)) } else { None } } /** Predicate based on the child output that matches data older than the watermark. */ - lazy val watermarkPredicateForData: Option[Predicate] = - watermarkExpression.map(newPredicate(_, child.output)) + lazy val watermarkPredicateForData: Option[BasePredicate] = + watermarkExpression.map(Predicate.create(_, child.output)) protected def removeKeysOlderThanWatermark(store: StateStore): Unit = { if (watermarkPredicateForKeys.nonEmpty) { From 73588fcc2ad71df6530f629bcb2c84f9c531209c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 20 Nov 2019 14:40:56 +0900 Subject: [PATCH 2/7] Address reviews --- .../catalog/ExternalCatalogUtils.scala | 5 ++--- .../sql/catalyst/expressions/predicates.scala | 22 +++++++++---------- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../catalyst/expressions/PredicateSuite.scala | 2 +- .../PartitioningAwareFileIndex.scala | 2 +- .../sql/sources/SimpleTextRelation.scala | 4 ++-- 6 files changed, 18 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 4cff162c116a4..770952940c08c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI -import java.util.Locale import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell @@ -26,7 +25,7 @@ import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, Predicate} object ExternalCatalogUtils { // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't @@ -148,7 +147,7 @@ object ExternalCatalogUtils { } val boundPredicate = - InterpretedPredicate.create(predicates.reduce(And).transform { + Predicate.create(predicates.reduce(And).transform { case att: AttributeReference => val index = partitionSchema.indexWhere(_.name == att.name) BoundReference(index, partitionSchema(index).dataType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b7844698c83ce..4cda69f5e0fa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ /** - * Interface for generated/interpreted predicate + * A base class for generated/interpreted predicate */ abstract class BasePredicate { def eval(r: InternalRow): Boolean @@ -45,13 +45,6 @@ abstract class BasePredicate { def initialize(partitionIndex: Int): Unit = {} } -object InterpretedPredicate { - def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = - create(BindReferences.bindReference(expression, inputSchema)) - - def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression) -} - case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] @@ -81,14 +74,21 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr } override protected def createInterpretedObject(in: Expression): BasePredicate = { - InterpretedPredicate.create(in) + InterpretedPredicate(in) + } + + /** + * Returns a BasePredicate for a bound Expression. + */ + def create(expr: Expression): BasePredicate = { + create(expr) } /** * Returns a BasePredicate for an Expression, which will be bound to `inputSchema`. */ - def create(exprs: Expression, inputSchema: Seq[Attribute]): BasePredicate = { - createObject(bindReference(exprs, inputSchema)) + def create(expr: Expression, inputSchema: Seq[Attribute]): BasePredicate = { + createObject(bindReference(expr, inputSchema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b78bdf082f333..f855596c363be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1507,7 +1507,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Filter(condition, LocalRelation(output, data, isStreaming)) if !hasUnevaluableExpr(condition) => - val predicate = InterpretedPredicate.create(condition, output) + val predicate = Predicate.create(condition, output) predicate.initialize(0) LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 52cdd988caa2e..67a41e7cc2767 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -510,7 +510,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Interpreted Predicate should initialize nondeterministic expressions") { - val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) + val interpreted = Predicate.create(LessThan(Rand(7), Literal(1.0))) interpreted.initialize(0) assert(interpreted.eval(new UnsafeRow())) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 3adec2f790730..2fa7c38c578f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -171,7 +171,7 @@ abstract class PartitioningAwareFileIndex( if (partitionPruningPredicates.nonEmpty) { val predicate = partitionPruningPredicates.reduce(expressions.And) - val boundPredicate = InterpretedPredicate.create(predicate.transform { + val boundPredicate = Predicate.create(predicate.transform { case a: AttributeReference => val index = partitionColumns.indexWhere(a.name == _.name) BoundReference(index, partitionColumns(index).dataType, nullable = true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 60a4638f610b3..d1b97b2852fbc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.sql.{sources, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedProjection, JoinedRow, Literal, Predicate} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.types.{DataType, StructType} @@ -88,7 +88,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { val attribute = inputAttributes.find(_.name == column).get expressions.GreaterThan(attribute, literal) }.reduceOption(expressions.And).getOrElse(Literal(true)) - InterpretedPredicate.create(filterCondition, inputAttributes) + Predicate.create(filterCondition, inputAttributes) } // Uses a simple projection to simulate column pruning From dd3088f06e7213559317916675b62b1fab7aa9ad Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 20 Nov 2019 15:05:37 +0900 Subject: [PATCH 3/7] Revert "Address reviews" This reverts commit 73588fcc2ad71df6530f629bcb2c84f9c531209c. --- .../catalog/ExternalCatalogUtils.scala | 5 +++-- .../sql/catalyst/expressions/predicates.scala | 22 +++++++++---------- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../catalyst/expressions/PredicateSuite.scala | 2 +- .../PartitioningAwareFileIndex.scala | 2 +- .../sql/sources/SimpleTextRelation.scala | 4 ++-- 6 files changed, 19 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 770952940c08c..4cff162c116a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell @@ -25,7 +26,7 @@ import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, Predicate} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate} object ExternalCatalogUtils { // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't @@ -147,7 +148,7 @@ object ExternalCatalogUtils { } val boundPredicate = - Predicate.create(predicates.reduce(And).transform { + InterpretedPredicate.create(predicates.reduce(And).transform { case att: AttributeReference => val index = partitionSchema.indexWhere(_.name == att.name) BoundReference(index, partitionSchema(index).dataType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4cda69f5e0fa1..b7844698c83ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ /** - * A base class for generated/interpreted predicate + * Interface for generated/interpreted predicate */ abstract class BasePredicate { def eval(r: InternalRow): Boolean @@ -45,6 +45,13 @@ abstract class BasePredicate { def initialize(partitionIndex: Int): Unit = {} } +object InterpretedPredicate { + def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = + create(BindReferences.bindReference(expression, inputSchema)) + + def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression) +} + case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] @@ -74,21 +81,14 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr } override protected def createInterpretedObject(in: Expression): BasePredicate = { - InterpretedPredicate(in) - } - - /** - * Returns a BasePredicate for a bound Expression. - */ - def create(expr: Expression): BasePredicate = { - create(expr) + InterpretedPredicate.create(in) } /** * Returns a BasePredicate for an Expression, which will be bound to `inputSchema`. */ - def create(expr: Expression, inputSchema: Seq[Attribute]): BasePredicate = { - createObject(bindReference(expr, inputSchema)) + def create(exprs: Expression, inputSchema: Seq[Attribute]): BasePredicate = { + createObject(bindReference(exprs, inputSchema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f855596c363be..b78bdf082f333 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1507,7 +1507,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Filter(condition, LocalRelation(output, data, isStreaming)) if !hasUnevaluableExpr(condition) => - val predicate = Predicate.create(condition, output) + val predicate = InterpretedPredicate.create(condition, output) predicate.initialize(0) LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 67a41e7cc2767..52cdd988caa2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -510,7 +510,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Interpreted Predicate should initialize nondeterministic expressions") { - val interpreted = Predicate.create(LessThan(Rand(7), Literal(1.0))) + val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) interpreted.initialize(0) assert(interpreted.eval(new UnsafeRow())) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 2fa7c38c578f4..3adec2f790730 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -171,7 +171,7 @@ abstract class PartitioningAwareFileIndex( if (partitionPruningPredicates.nonEmpty) { val predicate = partitionPruningPredicates.reduce(expressions.And) - val boundPredicate = Predicate.create(predicate.transform { + val boundPredicate = InterpretedPredicate.create(predicate.transform { case a: AttributeReference => val index = partitionColumns.indexWhere(a.name == _.name) BoundReference(index, partitionColumns(index).dataType, nullable = true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index d1b97b2852fbc..60a4638f610b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.sql.{sources, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedProjection, JoinedRow, Literal, Predicate} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.types.{DataType, StructType} @@ -88,7 +88,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { val attribute = inputAttributes.find(_.name == column).get expressions.GreaterThan(attribute, literal) }.reduceOption(expressions.And).getOrElse(Literal(true)) - Predicate.create(filterCondition, inputAttributes) + InterpretedPredicate.create(filterCondition, inputAttributes) } // Uses a simple projection to simulate column pruning From bed96f4b8be74f697eca0571bce24b8d074a7ee3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 20 Nov 2019 15:08:16 +0900 Subject: [PATCH 4/7] Address comments --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b7844698c83ce..6c3b2ec13c3da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ /** - * Interface for generated/interpreted predicate + * A base class for generated/interpreted predicate */ abstract class BasePredicate { def eval(r: InternalRow): Boolean @@ -81,7 +81,7 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr } override protected def createInterpretedObject(in: Expression): BasePredicate = { - InterpretedPredicate.create(in) + InterpretedPredicate(in) } /** From 0896afe8dd31595bd88465fc3959d35f9b692908 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 20 Nov 2019 15:10:47 +0900 Subject: [PATCH 5/7] Address more comments --- .../catalyst/catalog/ExternalCatalogUtils.scala | 5 ++--- .../sql/catalyst/expressions/predicates.scala | 16 +++++++--------- .../spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- .../catalyst/expressions/PredicateSuite.scala | 2 +- .../datasources/PartitioningAwareFileIndex.scala | 2 +- .../spark/sql/sources/SimpleTextRelation.scala | 4 ++-- 6 files changed, 14 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 4cff162c116a4..c47e48dc989bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI -import java.util.Locale import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell @@ -26,7 +25,7 @@ import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, Predicate} object ExternalCatalogUtils { // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't @@ -148,7 +147,7 @@ object ExternalCatalogUtils { } val boundPredicate = - InterpretedPredicate.create(predicates.reduce(And).transform { + Predicate.createInterpretedPredicate(predicates.reduce(And).transform { case att: AttributeReference => val index = partitionSchema.indexWhere(_.name == att.name) BoundReference(index, partitionSchema(index).dataType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 6c3b2ec13c3da..e0c3291fe24a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -45,13 +45,6 @@ abstract class BasePredicate { def initialize(partitionIndex: Int): Unit = {} } -object InterpretedPredicate { - def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = - create(BindReferences.bindReference(expression, inputSchema)) - - def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression) -} - case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] @@ -84,11 +77,16 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr InterpretedPredicate(in) } + def createInterpretedPredicate(e: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = + createInterpretedPredicate(BindReferences.bindReference(e, inputSchema)) + + def createInterpretedPredicate(e: Expression): InterpretedPredicate = InterpretedPredicate(e) + /** * Returns a BasePredicate for an Expression, which will be bound to `inputSchema`. */ - def create(exprs: Expression, inputSchema: Seq[Attribute]): BasePredicate = { - createObject(bindReference(exprs, inputSchema)) + def create(e: Expression, inputSchema: Seq[Attribute]): BasePredicate = { + createObject(bindReference(e, inputSchema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b78bdf082f333..f855596c363be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1507,7 +1507,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Filter(condition, LocalRelation(output, data, isStreaming)) if !hasUnevaluableExpr(condition) => - val predicate = InterpretedPredicate.create(condition, output) + val predicate = Predicate.create(condition, output) predicate.initialize(0) LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 52cdd988caa2e..5f87911db71ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -510,7 +510,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Interpreted Predicate should initialize nondeterministic expressions") { - val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) + val interpreted = Predicate.createInterpretedPredicate(LessThan(Rand(7), Literal(1.0))) interpreted.initialize(0) assert(interpreted.eval(new UnsafeRow())) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 3adec2f790730..427fffa027d23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -171,7 +171,7 @@ abstract class PartitioningAwareFileIndex( if (partitionPruningPredicates.nonEmpty) { val predicate = partitionPruningPredicates.reduce(expressions.And) - val boundPredicate = InterpretedPredicate.create(predicate.transform { + val boundPredicate = Predicate.createInterpretedPredicate(predicate.transform { case a: AttributeReference => val index = partitionColumns.indexWhere(a.name == _.name) BoundReference(index, partitionColumns(index).dataType, nullable = true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 60a4638f610b3..e9310e125d32f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.sql.{sources, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedProjection, JoinedRow, Literal, Predicate} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.types.{DataType, StructType} @@ -88,7 +88,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { val attribute = inputAttributes.find(_.name == column).get expressions.GreaterThan(attribute, literal) }.reduceOption(expressions.And).getOrElse(Literal(true)) - InterpretedPredicate.create(filterCondition, inputAttributes) + Predicate.createInterpretedPredicate(filterCondition, inputAttributes) } // Uses a simple projection to simulate column pruning From 74183e6e1a1ce930f07979678d494224b6bd41ad Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 20 Nov 2019 16:43:25 +0900 Subject: [PATCH 6/7] Fix --- .../sql/catalyst/catalog/ExternalCatalogUtils.scala | 2 +- .../spark/sql/catalyst/expressions/predicates.scala | 13 ++++++++++--- .../sql/catalyst/expressions/PredicateSuite.scala | 2 +- .../datasources/PartitioningAwareFileIndex.scala | 2 +- .../spark/sql/sources/SimpleTextRelation.scala | 2 +- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index c47e48dc989bf..ae3b75dc3334b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -147,7 +147,7 @@ object ExternalCatalogUtils { } val boundPredicate = - Predicate.createInterpretedPredicate(predicates.reduce(And).transform { + Predicate.createInterpreted(predicates.reduce(And).transform { case att: AttributeReference => val index = partitionSchema.indexWhere(_.name == att.name) BoundReference(index, partitionSchema(index).dataType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e0c3291fe24a1..51e1646dde4d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -77,10 +77,10 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr InterpretedPredicate(in) } - def createInterpretedPredicate(e: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = - createInterpretedPredicate(BindReferences.bindReference(e, inputSchema)) + def createInterpreted(e: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = + createInterpreted(BindReferences.bindReference(e, inputSchema)) - def createInterpretedPredicate(e: Expression): InterpretedPredicate = InterpretedPredicate(e) + def createInterpreted(e: Expression): InterpretedPredicate = InterpretedPredicate(e) /** * Returns a BasePredicate for an Expression, which will be bound to `inputSchema`. @@ -88,6 +88,13 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr def create(e: Expression, inputSchema: Seq[Attribute]): BasePredicate = { createObject(bindReference(e, inputSchema)) } + + /** + * Returns a BasePredicate for a given bound Expression. + */ + def create(e: Expression): BasePredicate = { + createObject(e) + } } trait PredicateHelper { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 5f87911db71ce..67a41e7cc2767 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -510,7 +510,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Interpreted Predicate should initialize nondeterministic expressions") { - val interpreted = Predicate.createInterpretedPredicate(LessThan(Rand(7), Literal(1.0))) + val interpreted = Predicate.create(LessThan(Rand(7), Literal(1.0))) interpreted.initialize(0) assert(interpreted.eval(new UnsafeRow())) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 427fffa027d23..21ddeb6491155 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -171,7 +171,7 @@ abstract class PartitioningAwareFileIndex( if (partitionPruningPredicates.nonEmpty) { val predicate = partitionPruningPredicates.reduce(expressions.And) - val boundPredicate = Predicate.createInterpretedPredicate(predicate.transform { + val boundPredicate = Predicate.createInterpreted(predicate.transform { case a: AttributeReference => val index = partitionColumns.indexWhere(a.name == _.name) BoundReference(index, partitionColumns(index).dataType, nullable = true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index e9310e125d32f..d1b97b2852fbc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -88,7 +88,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { val attribute = inputAttributes.find(_.name == column).get expressions.GreaterThan(attribute, literal) }.reduceOption(expressions.And).getOrElse(Literal(true)) - Predicate.createInterpretedPredicate(filterCondition, inputAttributes) + Predicate.create(filterCondition, inputAttributes) } // Uses a simple projection to simulate column pruning From 164507ce26ddcf20e9c970cccf7746cc78a6119d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 20 Nov 2019 16:50:31 +0900 Subject: [PATCH 7/7] Fix --- .../org/apache/spark/sql/catalyst/expressions/predicates.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 51e1646dde4d0..bcd442ad3cc35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -77,9 +77,6 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr InterpretedPredicate(in) } - def createInterpreted(e: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = - createInterpreted(BindReferences.bindReference(e, inputSchema)) - def createInterpreted(e: Expression): InterpretedPredicate = InterpretedPredicate(e) /**