From 62c569672083c0fa633da1d6edaba40d0bb05819 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 17 Jan 2018 13:58:12 -0800 Subject: [PATCH 1/5] SPARK-22386: DataSourceV2: Use immutable logical plans. --- .../apache/spark/sql/DataFrameReader.scala | 45 ++-- .../datasources/v2/DataSourceV2Relation.scala | 203 ++++++++++++++++-- .../datasources/v2/DataSourceV2Strategy.scala | 7 +- .../datasources/v2/DataSourceV2Utils.scala | 30 +++ .../v2/PushDownOperatorsToDataSource.scala | 146 ++++--------- .../continuous/ContinuousExecution.scala | 2 +- .../sql/sources/v2/DataSourceV2Suite.scala | 2 +- .../sources/v2/DataSourceV2UtilsSuite.scala | 33 ++- .../spark/sql/streaming/StreamTest.scala | 6 +- 9 files changed, 316 insertions(+), 158 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 46b5f54a33f74..b1a3fcf44c2d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -185,39 +185,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) - - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading - // the dataframe as a v1 source. 
- val reader = (ds, userSpecifiedSchema) match { - case (ds: ReadSupportWithSchema, Some(schema)) => - ds.createReader(schema, options) - - case (ds: ReadSupport, None) => - ds.createReader(options) - - case (ds: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $ds.") - - case (ds: ReadSupport, Some(schema)) => - val reader = ds.createReader(options) - if (reader.readSchema() != schema) { - throw new AnalysisException(s"$ds does not allow user-specified schemas.") - } - reader - - case _ => null // fall back to v1 - } - - if (reader == null) { - loadV1Source(paths: _*) + val ds = cls.newInstance().asInstanceOf[DataSourceV2] + val (pathOption, tableOption) = DataSourceV2Utils.parseTableLocation( + sparkSession, extraOptions.get("path")) + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = ds, conf = sparkSession.sessionState.conf) + + if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + Dataset.ofRows(sparkSession, DataSourceV2Relation( + ds, extraOptions.toMap ++ sessionOptions, pathOption, tableOption, + userSchema = userSpecifiedSchema)) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + loadV1Source(paths: _*) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index eebfa29f91b99..04f1df6a26324 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,17 +17,151 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema, WriteSupport} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter +import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( - fullOutput: Seq[AttributeReference], - reader: DataSourceReader) - extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { + source: DataSourceV2, + options: Map[String, String], + path: Option[String] = None, + table: Option[TableIdentifier] = None, + projection: Option[Seq[AttributeReference]] = None, + filters: Option[Seq[Expression]] = None, + userSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation { + + override def simpleString: String = { + 
"DataSourceV2Relation(" + + s"source=$sourceName${path.orElse(table).map(loc => s"($loc)").getOrElse("")}, " + + s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " + + s"filters=[${pushedFilters.mkString(", ")}] options=$options)" + } + + override lazy val schema: StructType = reader.readSchema() + + override lazy val output: Seq[AttributeReference] = { + projection match { + case Some(attrs) => + // use the projection attributes to avoid assigning new ids. fields that are not projected + // will be assigned new ids, which is okay because they are not projected. + val attrMap = attrs.map(a => a.name -> a).toMap + schema.map(f => attrMap.getOrElse(f.name, + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) + case _ => + schema.toAttributes + } + } + + private lazy val v2Options: DataSourceOptions = { + // ensure path and table options are set correctly + val updatedOptions = new mutable.HashMap[String, String] + updatedOptions ++= options + + path match { + case Some(p) => + updatedOptions.put("path", p) + case None => + updatedOptions.remove("path") + } + + table.map { ident => + updatedOptions.put("table", ident.table) + ident.database match { + case Some(db) => + updatedOptions.put("database", db) + case None => + updatedOptions.remove("database") + } + } + + new DataSourceOptions(options.asJava) + } + + private val sourceName: String = { + source match { + case registered: DataSourceRegister => + registered.shortName() + case _ => + source.getClass.getSimpleName + } + } + + lazy val ( + reader: DataSourceReader, + unsupportedFilters: Seq[Expression], + pushedFilters: Seq[Expression]) = { + val newReader = userSchema match { + case Some(s) => + asReadSupportWithSchema.createReader(s, v2Options) + case _ => + asReadSupport.createReader(v2Options) + } + + projection.foreach { attrs => + DataSourceV2Relation.pushRequiredColumns(newReader, attrs.toStructType) + } + + val (remainingFilters, pushedFilters) = filters match { + case Some(filterSeq) => + DataSourceV2Relation.pushFilters(newReader, filterSeq) + case _ => + (Nil, Nil) + } - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + (newReader, remainingFilters, pushedFilters) + } + + def writer(dfSchema: StructType, mode: SaveMode): Option[DataSourceWriter] = { + val writer = asWriteSupport.createWriter(UUID.randomUUID.toString, dfSchema, mode, v2Options) + if (writer.isPresent) Some(writer.get()) else None + } + + private lazy val asReadSupport: ReadSupport = { + source match { + case support: ReadSupport => + support + case _: ReadSupportWithSchema => + // this method is only called if there is no user-supplied schema. if there is no + // user-supplied schema and ReadSupport was not implemented, throw a helpful exception. 
+ throw new AnalysisException(s"Data source requires a user-supplied schema: $sourceName") + case _ => + throw new AnalysisException(s"Data source is not readable: $sourceName") + } + } + + private lazy val asReadSupportWithSchema: ReadSupportWithSchema = { + source match { + case support: ReadSupportWithSchema => + support + case _: ReadSupport => + throw new AnalysisException( + s"Data source does not support user-supplied schema: $sourceName") + case _ => + throw new AnalysisException(s"Data source is not readable: $sourceName") + } + } + + private lazy val asWriteSupport: WriteSupport = { + source match { + case support: WriteSupport => + support + case _ => + throw new AnalysisException(s"Data source is not writable: $sourceName") + } + } override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => @@ -37,7 +171,9 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - copy(fullOutput = fullOutput.map(_.newInstance())) + // projection is used to maintain id assignment. + // if projection is not set, use output so the copy is not equal to the original + copy(projection = Some(projection.getOrElse(output).map(_.newInstance()))) } } @@ -45,14 +181,57 @@ case class DataSourceV2Relation( * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical * to the non-streaming relation. */ -class StreamingDataSourceV2Relation( +case class StreamingDataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) { + reader: DataSourceReader) + extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation { override def isStreaming: Boolean = true + + override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation] + + override def newInstance(): LogicalPlan = copy(fullOutput = fullOutput.map(_.newInstance())) } object DataSourceV2Relation { - def apply(reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { + reader match { + case projectionSupport: SupportsPushDownRequiredColumns => + projectionSupport.pruneColumns(struct) + case _ => + } + } + + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case catalystFilterSupport: SupportsPushDownCatalystFilters => + ( + catalystFilterSupport.pushCatalystFilters(filters.toArray), + catalystFilterSupport.pushedCatalystFilters() + ) + + case filterSupport: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, Filter] = filters.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. 
+ val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet + val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => + unhandledFilters.contains(f) + } + + (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) + + case _ => (filters, Nil) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f54..5c992866df6b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case relation: DataSourceV2Relation => + DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + + case relation: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(relation.fullOutput, relation.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 5267f5f1580c3..e028c8016d915 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.regex.Pattern import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} @@ -55,4 +57,32 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } + + /** + * Helper method to parse the argument passed to load or save. If the path doesn't contain '/' + * and cannot be a fully-qualified location, parse it as a table identifier. Otherwise, return + * the path. + * + * @param sparkSession a [[SparkSession]] + * @param pathOrTable some string passed to load or save, or None + * @return + */ + def parseTableLocation( + sparkSession: SparkSession, + pathOrTable: Option[String]): (Option[String], Option[TableIdentifier]) = { + pathOrTable match { + case Some(path) if !path.contains("/") => + // without "/", this cannot be a full path. 
parse it as a table name + val ident = sparkSession.sessionState.sqlParser.parseTableIdentifier(path) + // ensure the database is set correctly + val db = ident.database.getOrElse(sparkSession.catalog.currentDatabase) + (None, Some(ident.copy(database = Some(db)))) + + case Some(path) => + (Some(path), None) + + case _ => + (None, None) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 566a48394f02e..45711925092f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -17,119 +17,55 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper} -import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.v2.reader._ - -/** - * Pushes down various operators to the underlying data source for better performance. Operators are - * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you - * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the - * data source should execute FILTER before LIMIT. And required columns are calculated at the end, - * because when more operators are pushed down, we may need less columns at Spark side. - */ -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { - override def apply(plan: LogicalPlan): LogicalPlan = { - // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may - // appear in many places for column pruning. - // TODO: Ideally column pruning should be implemented via a plan property that is propagated - // top-down, then we can simplify the logic here and only collect target operators. - val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => - val (candidates, nonDeterministic) = - splitConjunctivePredicates(condition).partition(_.deterministic) - - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(candidates.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. 
An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => candidates - } - - val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And) - val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) - if (withFilter.output == fields) { - withFilter - } else { - Project(fields, withFilter) - } - } - - // TODO: add more push down rules. - - pushDownRequiredColumns(filterPushed, filterPushed.outputSet) - // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(filterPushed) - } - - // TODO: nested fields pruning - private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = { - plan match { - case Project(projectList, child) => - val required = projectList.flatMap(_.references) - pushDownRequiredColumns(child, AttributeSet(required)) - - case Filter(condition, child) => - val required = requiredByParent ++ condition.references - pushDownRequiredColumns(child, required) - - case relation: DataSourceV2Relation => relation.reader match { - case reader: SupportsPushDownRequiredColumns => - val requiredColumns = relation.output.filter(requiredByParent.contains) - reader.pruneColumns(requiredColumns.toStructType) +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { + override def apply( + plan: LogicalPlan): LogicalPlan = plan transformUp { + // PhysicalOperation guarantees that filters are deterministic; no need to check + case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => + // merge the filters + val filters = relation.filters match { + case Some(existing) => + existing ++ newFilters case _ => + newFilters } - // TODO: there may be more operators that can be used to calculate the required columns. We - // can add more and more in the future. - case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet)) - } - } - - /** - * Finds a Filter node(with an optional Project child) above data source relation. - */ - object FilterAndProject { - // returns the project list, the filter condition and the data source relation. - def unapply(plan: LogicalPlan) - : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + val projectAttrs = project.map(_.toAttribute) + val projectSet = AttributeSet(project.flatMap(_.references)) + val filterSet = AttributeSet(filters.flatMap(_.references)) + + val projection = if (filterSet.subsetOf(projectSet) && + AttributeSet(projectAttrs) == projectSet) { + // When the required projection contains all of the filter columns and column pruning alone + // can produce the required projection, push the required projection. + // A final projection may still be needed if the data source produces a different column + // order or if it cannot prune all of the nested columns. + projectAttrs + } else { + // When there are filter columns not already in the required projection or when the required + // projection is more complicated than column pruning, base column pruning on the set of + // all columns needed by both. 
+ (projectSet ++ filterSet).toSeq + } - case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + val newRelation = relation.copy( + projection = Some(projection.asInstanceOf[Seq[AttributeReference]]), + filters = Some(filters)) - case Filter(condition, Project(fields, r: DataSourceV2Relation)) - if fields.forall(_.deterministic) => - val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) - val substituted = condition.transform { - case a: Attribute => attributeMap.getOrElse(a, a) - } - Some((fields, substituted, r)) + // Add a Filter for any filters that could not be pushed + val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) + val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) - case _ => None - } + // Add a Project to ensure the output matches the required projection + if (newRelation.output != projectAttrs) { + Project(project, filtered) + } else { + filtered + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 08c81419a9d34..01007b3db5747 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -203,7 +203,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index eccd45442a3b2..af629fb9038b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -146,7 +146,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) - assert(e.message.contains("A schema needs to be specified")) + assert(e.message.contains("requires a user-supplied schema")) val schema = new StructType().add("i", "int").add("s", "string") val df = spark.read.format(cls.getName).schema(schema).load() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index 4911e3225552d..9a7cadb207be7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.sources.v2 +import java.net.URI + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogDatabase import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext -class DataSourceV2UtilsSuite extends SparkFunSuite { 
+class DataSourceV2UtilsSuite extends SparkFunSuite with SharedSQLContext { private val keyPrefix = new DataSourceV2WithSessionConfig().keyPrefix @@ -41,6 +46,32 @@ class DataSourceV2UtilsSuite extends SparkFunSuite { assert(confs.keySet.contains("foo.bar")) assert(confs.keySet.contains("whateverConfigName")) } + + test("parseTableLocation") { + import DataSourceV2Utils.parseTableLocation + // no location + assert((None, None) === parseTableLocation(spark, None)) + + // file paths + val s3Path = "s3://bucket/path/file.ext" + assert((Some(s3Path), None) === parseTableLocation(spark, Some(s3Path))) + val hdfsPath = "hdfs://nn:8020/path/file.ext" + assert((Some(hdfsPath), None) === parseTableLocation(spark, Some(hdfsPath))) + val localPath = "/path/file.ext" + assert((Some(localPath), None) === parseTableLocation(spark, Some(localPath))) + + // table names + assert( + (None, Some(TableIdentifier("t", Some("default")))) === parseTableLocation(spark, Some("t"))) + assert( + (None, Some(TableIdentifier("t", Some("db")))) === parseTableLocation(spark, Some("db.t"))) + + spark.sessionState.catalog.createDatabase( + CatalogDatabase("test", "test", URI.create("file:/tmp"), Map.empty), ignoreIfExists = true) + spark.sessionState.catalog.setCurrentDatabase("test") + assert( + (None, Some(TableIdentifier("t", Some("test")))) === parseTableLocation(spark, Some("t"))) + } } class DataSourceV2WithSessionConfig extends SimpleDataSourceV2 with SessionConfigSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d6433562fb29b..3b787912a5e00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -35,10 +35,10 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder, encoderFor} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r + case StreamingDataSourceV2Relation(_, r) => r } .zipWithIndex .find(_._1 == source) From f0bd45d3c931941b8092cdac738cb29954e0acdd Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 24 Jan 2018 11:34:42 -0800 Subject: [PATCH 2/5] SPARK-23203: Fix scala style check. 
--- .../scala/org/apache/spark/sql/streaming/StreamTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 3b787912a5e00..06235790331d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -35,12 +35,12 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation} +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ From 2fdeb4556cd22a092630b341a22a16a59e377183 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 24 Jan 2018 11:54:10 -0800 Subject: [PATCH 3/5] SPARK-23203: Fix Kafka tests, use StreamingDataSourceV2Relation. This also removes unused imports. --- .../kafka010/KafkaContinuousSourceSuite.scala | 19 ++++--------------- .../sql/kafka010/KafkaContinuousTest.scala | 4 ++-- .../spark/sql/kafka010/KafkaSourceSuite.scala | 4 ++-- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index a7083fa4e3417..f679e9bfc0450 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,20 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. 
class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5a1a14f7a307a..48ac3fc1e8f9d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 02c87643568bd..d26beca800bcc 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,7 +34,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext import org.apache.spark.sql.{Dataset, ForeachWriter} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} @@ -117,7 +117,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader } }) if (sources.isEmpty) { From ab945a19efe666c41deae9c044002f3455220c1d Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 2 Feb 2018 12:30:33 -0800 Subject: [PATCH 4/5] SPARK-23204: DataFrameReader: Remove v2 table identifier parsing. 
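
With this change, load() no longer tries to interpret the path argument as a
database.table name; the string is forwarded untouched as the "path" option and
DataSourceV2Relation is built without a TableIdentifier. A sketch of the
resulting v2 read path, condensed from the DataFrameReader change below
(illustrative, not new API):

    if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) {
      val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
        ds = ds, conf = sparkSession.sessionState.conf)
      Dataset.ofRows(sparkSession, DataSourceV2Relation(
        ds, extraOptions.toMap ++ sessionOptions, path = extraOptions.get("path"),
        userSchema = userSpecifiedSchema))
    } else {
      loadV1Source(paths: _*)
    }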
--- .../apache/spark/sql/DataFrameReader.scala | 10 +++--- .../datasources/v2/DataSourceV2Utils.scala | 28 ---------------- .../sources/v2/DataSourceV2UtilsSuite.scala | 33 +------------------ 3 files changed, 5 insertions(+), 66 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b1a3fcf44c2d6..8705888986f50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -186,15 +186,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance().asInstanceOf[DataSourceV2] - val (pathOption, tableOption) = DataSourceV2Utils.parseTableLocation( - sparkSession, extraOptions.get("path")) - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = ds, conf = sparkSession.sessionState.conf) - if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = ds, conf = sparkSession.sessionState.conf) Dataset.ofRows(sparkSession, DataSourceV2Relation( - ds, extraOptions.toMap ++ sessionOptions, pathOption, tableOption, + ds, extraOptions.toMap ++ sessionOptions, path = extraOptions.get("path"), userSchema = userSpecifiedSchema)) + } else { loadV1Source(paths: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index e028c8016d915..a506aef891535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -57,32 +57,4 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } - - /** - * Helper method to parse the argument passed to load or save. If the path doesn't contain '/' - * and cannot be a fully-qualified location, parse it as a table identifier. Otherwise, return - * the path. - * - * @param sparkSession a [[SparkSession]] - * @param pathOrTable some string passed to load or save, or None - * @return - */ - def parseTableLocation( - sparkSession: SparkSession, - pathOrTable: Option[String]): (Option[String], Option[TableIdentifier]) = { - pathOrTable match { - case Some(path) if !path.contains("/") => - // without "/", this cannot be a full path. 
parse it as a table name - val ident = sparkSession.sessionState.sqlParser.parseTableIdentifier(path) - // ensure the database is set correctly - val db = ident.database.getOrElse(sparkSession.catalog.currentDatabase) - (None, Some(ident.copy(database = Some(db)))) - - case Some(path) => - (Some(path), None) - - case _ => - (None, None) - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index 9a7cadb207be7..4911e3225552d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -17,16 +17,11 @@ package org.apache.spark.sql.sources.v2 -import java.net.URI - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogDatabase import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext -class DataSourceV2UtilsSuite extends SparkFunSuite with SharedSQLContext { +class DataSourceV2UtilsSuite extends SparkFunSuite { private val keyPrefix = new DataSourceV2WithSessionConfig().keyPrefix @@ -46,32 +41,6 @@ class DataSourceV2UtilsSuite extends SparkFunSuite with SharedSQLContext { assert(confs.keySet.contains("foo.bar")) assert(confs.keySet.contains("whateverConfigName")) } - - test("parseTableLocation") { - import DataSourceV2Utils.parseTableLocation - // no location - assert((None, None) === parseTableLocation(spark, None)) - - // file paths - val s3Path = "s3://bucket/path/file.ext" - assert((Some(s3Path), None) === parseTableLocation(spark, Some(s3Path))) - val hdfsPath = "hdfs://nn:8020/path/file.ext" - assert((Some(hdfsPath), None) === parseTableLocation(spark, Some(hdfsPath))) - val localPath = "/path/file.ext" - assert((Some(localPath), None) === parseTableLocation(spark, Some(localPath))) - - // table names - assert( - (None, Some(TableIdentifier("t", Some("default")))) === parseTableLocation(spark, Some("t"))) - assert( - (None, Some(TableIdentifier("t", Some("db")))) === parseTableLocation(spark, Some("db.t"))) - - spark.sessionState.catalog.createDatabase( - CatalogDatabase("test", "test", URI.create("file:/tmp"), Map.empty), ignoreIfExists = true) - spark.sessionState.catalog.setCurrentDatabase("test") - assert( - (None, Some(TableIdentifier("t", Some("test")))) === parseTableLocation(spark, Some("t"))) - } } class DataSourceV2WithSessionConfig extends SimpleDataSourceV2 with SessionConfigSupport { From 3580daf15497a1d49112a0eddd556f74b9b3e280 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 2 Feb 2018 11:04:23 -0800 Subject: [PATCH 5/5] SPARK-23321: Apply preprocess insert rules to DataSourceV2. This updates the DataSourceV2 write path to use DataSourceV2Relation and InsertIntoTable to apply the insert preprocess rules. 
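
A sketch of the new v2 save path, condensed from the DataFrameWriter and
planner changes below (illustrative only; names are the ones introduced in
this patch):

    // DataFrameWriter.save builds a logical insert instead of creating a writer
    val relation = DataSourceV2Relation(
      ds, extraOptions.toMap ++ sessionOptions, path = extraOptions.get("path"))
    val (overwrite, ifNotExists) = DataSourceV2Utils.overwriteAndIfNotExists(mode)
    runCommand(df.sparkSession, "save") {
      InsertIntoTable(relation, Map.empty, df.logicalPlan, overwrite, ifNotExists)
    }

Because the plan now goes through the analyzer, PreprocessTableInsertion applies
to DataSourceV2Relation the same way it does to v1 relations, and planning
converts the InsertIntoTable back to a SaveMode via
DataSourceV2Utils.saveMode(overwrite, ifNotExists) before creating the
DataSourceWriter.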
--- .../plans/logical/basicLogicalOperators.scala | 4 +-- .../apache/spark/sql/DataFrameWriter.scala | 32 ++++++++----------- .../sql/execution/command/commands.scala | 6 ++++ .../datasources/DataSourceStrategy.scala | 9 ++++++ .../sql/execution/datasources/rules.scala | 4 +++ .../datasources/v2/DataSourceV2Strategy.scala | 6 +++- .../datasources/v2/DataSourceV2Utils.scala | 29 +++++++++++++++-- .../sql/sources/v2/DataSourceV2Suite.scala | 16 ++++++++++ 8 files changed, 82 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a4fca790dd086..ffd6b30b8cf93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -368,8 +368,8 @@ case class InsertIntoTable( overwrite: Boolean, ifPartitionNotExists: Boolean) extends LogicalPlan { - // IF NOT EXISTS is only valid in INSERT OVERWRITE - assert(overwrite || !ifPartitionNotExists) + // overwrite=false and ifPartitionNotExists=false are used to pass mode=Ignore + // IF NOT EXISTS is only valid in static partitions assert(partition.values.forall(_.nonEmpty) || !ifPartitionNotExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ed7a9100cc7f1..8a3f0a0aa0981 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import java.text.SimpleDateFormat -import java.util.{Date, Locale, Properties, UUID} +import java.util.{Locale, Properties} import scala.collection.JavaConverters._ @@ -30,8 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, InsertIntoT import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType @@ -240,22 +238,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() + val ds = cls.newInstance().asInstanceOf[DataSourceV2] ds match { - case ws: WriteSupport => - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = df.sparkSession.sessionState.conf)).asJava) - // Using a timestamp and a random UUID to distinguish different writing jobs. This is good - // enough as there won't be tons of writing jobs created at the same second. 
- val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) - .format(new Date()) + "-" + UUID.randomUUID() - val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) - if (writer.isPresent) { - runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(writer.get(), df.logicalPlan) - } + case _: WriteSupport => + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = ds, conf = df.sparkSession.sessionState.conf) + val relation = DataSourceV2Relation( + ds, extraOptions.toMap ++ sessionOptions, path = extraOptions.get("path")) + + val (overwrite, ifNotExists) = DataSourceV2Utils.overwriteAndIfNotExists(mode) + + runCommand(df.sparkSession, "save") { + InsertIntoTable(relation, Map.empty, df.logicalPlan, overwrite, ifNotExists) } // Streaming also uses the data source V2 API. So it may be that the data source implements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 2cc0e38adc2ee..853066e362500 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -46,6 +46,12 @@ trait RunnableCommand extends Command { def run(sparkSession: SparkSession): Seq[Row] } +case class NoopCommand() extends RunnableCommand { + override def run(sparkSession: SparkSession): Seq[Row] = { + Seq.empty[Row] + } +} + /** * A physical operator that executes the run method of a `RunnableCommand` and * saves the result to prevent multiple executions. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d94c5bbccdd84..a5bad87d69110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -145,6 +146,14 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast parts, query, overwrite, false) if parts.isEmpty => InsertIntoDataSourceCommand(l, query, overwrite) + case InsertIntoTable(rel: DataSourceV2Relation, _, query, overwrite, ifNotExists) => + val writer = rel.writer(query.schema, DataSourceV2Utils.saveMode(overwrite, ifNotExists)) + if (writer.isDefined) { + WriteToDataSourceV2(writer.get, query) + } else { + NoopCommand() + } + case InsertIntoDir(_, storage, provider, query, overwrite) if provider.isDefined && provider.get.toLowerCase(Locale.ROOT) != DDLUtils.HIVE_PROVIDER => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5dbcf4a915cbf..0b00078298c3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.{AtomicType, StructType} @@ -392,6 +393,9 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit case LogicalRelation(_: InsertableRelation, _, catalogTable, _) => val tblName = catalogTable.map(_.identifier.quotedString).getOrElse("unknown") preprocess(i, tblName, Nil) + case relation: DataSourceV2Relation => + val tableName = relation.table.map(_.toString).orElse(relation.path).getOrElse("unknown") + preprocess(i, tableName, Nil) case _ => i } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 5c992866df6b8..934037022c967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { @@ -29,6 +29,10 @@ object DataSourceV2Strategy extends Strategy { case relation: StreamingDataSourceV2Relation => DataSourceV2ScanExec(relation.fullOutput, relation.reader) :: Nil + case InsertIntoTable(relation: DataSourceV2Relation, _, query, overwrite, ifNotExists) => + val mode = DataSourceV2Utils.saveMode(overwrite, ifNotExists) + WriteToDataSourceV2Exec(relation.writer(query.schema, mode).get, planLater(query)) :: Nil + case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index a506aef891535..9f6c01386a9e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.regex.Pattern import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} @@ -57,4 +56,30 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } + + def overwriteAndIfNotExists(mode: SaveMode): (Boolean, Boolean) = { + mode match { + case SaveMode.Ignore => + (false, true) + case SaveMode.Append => + (false, false) + case SaveMode.Overwrite => + (true, false) + case SaveMode.ErrorIfExists => 
+ (true, true) + } + } + + def saveMode(overwrite: Boolean, ifNotExists: Boolean): SaveMode = { + (overwrite, ifNotExists) match { + case (false, true) => + SaveMode.Ignore + case (false, false) => + SaveMode.Append + case (true, false) => + SaveMode.Overwrite + case (true, true) => + SaveMode.ErrorIfExists + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index af629fb9038b5..593ed361e597a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -259,6 +259,22 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("SPARK-23321: test preprocess table insertion is applied") { + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + // attempt to write more columns than the table contains + val e = intercept[AnalysisException] { + spark.range(10).select('id, -'id, 'id * 'id) + .write.format(cls.getName).option("path", path).save() + } + assert(e.message.contains("data to be inserted have the same number of columns")) + } + } + } + test("simple counter in writer with onDataWriterCommit") { Seq(classOf[SimpleWritableDataSource]).foreach { cls => withTempPath { file =>