diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCreateTable.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCreateTable.java new file mode 100644 index 0000000000000..454639b16b22a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCreateTable.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +import java.util.Map; + +/** + * A mix-in interface for {@link TableProvider} implementations that can create new tables for + * the given options. These tables are not stored in any catalog, but the data source provides + * a mechanism to check whether a table can be created for the specified + * options. + */ +@Evolving +public interface SupportsCreateTable extends TableProvider { + + /** + * Check whether a new table can be created for the given options. + * + * @param options The options that should be sufficient to define and access a table + * @return true if a new table can be created for the given options, false otherwise + */ + boolean canCreateTable(CaseInsensitiveStringMap options); + + /** + * Create a table with the given options. It is the data source's responsibility to check whether + * the provided schema and partition transforms are acceptable in case a table already exists + * for the given options. + * + * @param options The data source options that define how to access the table. This can contain + * the path for file-based tables, the broker addresses to connect to Kafka, or + * the JDBC URL to connect to a JDBC data source. + * @param schema The schema of the new table, as a struct type + * @param partitions Transforms to use for partitioning data in the table + * @param properties A string map of table properties + * @return Metadata for the new table. The table creation can be followed up by a write. + * @throws IllegalArgumentException If a table already exists for these options with a + * non-conforming schema or different partitioning specification.
+ * @throws UnsupportedOperationException If a requested partition transform is not supported or + * table properties are not supported + */ + Table buildTable( + CaseInsensitiveStringMap options, + StructType schema, + Transform[] partitions, + Map<String, String> properties); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 58acfb836b305..4bfe1fac23ed1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect} import org.apache.spark.sql.catalyst.plans.logical.sql.InsertIntoStatement import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, SupportsWrite, TableCatalog, TableProvider, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, SupportsCreateTable, SupportsWrite, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform} import org.apache.spark.sql.execution.SQLExecution @@ -249,49 +249,80 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { "write files of Hive data source directly.") } - assertNotBucketed("save") - val maybeV2Provider = lookupV2Provider() if (maybeV2Provider.isDefined) { - if (partitioningColumns.nonEmpty) { - throw new AnalysisException( - "Cannot write data to TableProvider implementation if partition columns are specified.") - } - val provider = maybeV2Provider.get val sessionOptions = DataSourceV2Utils.extractSessionConfigs( provider, df.sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) - import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - provider.getTable(dsOptions) match { - case table: SupportsWrite if table.supports(BATCH_WRITE) => - lazy val relation = DataSourceV2Relation.create(table, dsOptions) - modeForDSV2 match { - case SaveMode.Append => - runCommand(df.sparkSession, "save") { - AppendData.byName(relation, df.logicalPlan) + provider match { + case supportsCreate: SupportsCreateTable => + val canCreate = supportsCreate.canCreateTable(dsOptions) + if (modeForDSV1 == SaveMode.ErrorIfExists && !canCreate) { + throw new AnalysisException("Table already exists for the given data source options.") + } else if (modeForDSV1 == SaveMode.Ignore && !canCreate) { + // do nothing + return + } + supportsCreate.buildTable( + dsOptions, + df.schema.asNullable, + getV2Transforms(), + Map.empty[String, String].asJava) match { + case table: SupportsWrite if table.supports(BATCH_WRITE) => + lazy val relation = DataSourceV2Relation.create(table, dsOptions) + modeForDSV1 match { + case SaveMode.Append | SaveMode.ErrorIfExists | SaveMode.Ignore => + runCommand(df.sparkSession, "save") { + AppendData.byName(relation, df.logicalPlan) + } + + case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => + // truncate the table + runCommand(df.sparkSession, "save") { + OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) + } } + } - case 
SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => - // truncate the table - runCommand(df.sparkSession, "save") { - OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) + case _: TableProvider => + if (partitioningColumns.nonEmpty) { + throw new AnalysisException("Cannot write data to TableProvider implementation " + + "if partition columns are specified.") + } + assertNotBucketed("save") + + provider.getTable(dsOptions) match { + case table: SupportsWrite if table.supports(BATCH_WRITE) => + lazy val relation = DataSourceV2Relation.create(table, dsOptions) + modeForDSV2 match { + case SaveMode.Append => + runCommand(df.sparkSession, "save") { + AppendData.byName(relation, df.logicalPlan) + } + + case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => + // truncate the table + runCommand(df.sparkSession, "save") { + OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) + } + + case other => + throw new AnalysisException(s"TableProvider implementation $source cannot be " + + s"written with $other mode, please use Append or Overwrite " + + "modes instead.") } - case other => - throw new AnalysisException(s"TableProvider implementation $source cannot be " + - s"written with $other mode, please use Append or Overwrite " + - "modes instead.") + // Streaming also uses the data source V2 API. So it may be that the data source + // implements v2, but has no v2 implementation for batch writes. In that case, we fall + // back to saving as though it's a V1 source. + case _ => saveToV1Source() } - - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving - // as though it's a V1 source. - case _ => saveToV1Source() } } else { + assertNotBucketed("save") saveToV1Source() } } @@ -508,13 +539,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def saveAsTable(catalog: TableCatalog, ident: Identifier, mode: SaveMode): Unit = { - val partitioning = partitioningColumns.map { colNames => - colNames.map(name => IdentityTransform(FieldReference(name))) - }.getOrElse(Seq.empty[Transform]) - val bucketing = bucketColumnNames.map { cols => - Seq(BucketTransform(LiteralValue(numBuckets.get, IntegerType), cols.map(FieldReference(_)))) - }.getOrElse(Seq.empty[Transform]) - val partitionTransforms = partitioning ++ bucketing + val partitionTransforms = getV2Transforms() val tableOpt = try Option(catalog.loadTable(ident)) catch { case _: NoSuchTableException => None @@ -627,6 +652,21 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { CreateTable(tableDesc, modeForDSV1, Some(df.logicalPlan))) } + private def getV2Transforms(): Array[Transform] = { + val partitioning = partitioningColumns.map { colNames => + colNames.map(name => IdentityTransform(FieldReference(name))) + }.getOrElse(Seq.empty[Transform]) + val bucketing = getBucketSpec.map { spec => + if (spec.sortColumnNames.nonEmpty) { + throw new UnsupportedOperationException("V2 tables don't support bucketing with sorting.") + } + val cols = spec.bucketColumnNames + Seq(BucketTransform(LiteralValue(numBuckets.get, IntegerType), cols.map(FieldReference(_)))) + }.getOrElse(Seq.empty[Transform]) + + (partitioning ++ bucketing).toArray + } + /** * Saves the content of the `DataFrame` to an external database table via JDBC. 
In the case the * table already exists in the external database, behavior of this function depends on the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index cd811bb7afb51..cc3442d059b1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.connector -import org.apache.spark.sql.{DataFrame, Row, SaveMode} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue} +import org.apache.spark.sql.types.IntegerType class DataSourceV2DataFrameSuite extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = false) { @@ -26,11 +28,13 @@ class DataSourceV2DataFrameSuite before { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName) + InMemoryV1Provider.clear() } after { spark.sessionState.catalogManager.reset() spark.sessionState.conf.clear() + InMemoryV1Provider.clear() } override protected val catalogAndNamespace: String = "testcat.ns1.ns2.tbls" @@ -122,4 +126,68 @@ class DataSourceV2DataFrameSuite checkAnswer(spark.table(t1), Seq(Row("c", "d"))) } } + + SaveMode.values().foreach { mode => + test(s"save: new table creations with partitioning for table - mode: $mode") { + val format = classOf[InMemoryV1Provider].getName + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + df.write.mode(mode).option("name", "t1").format(format).partitionBy("a").save() + + checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df) + assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable) + assert(InMemoryV1Provider.tables("t1").partitioning.sameElements( + Array(IdentityTransform(FieldReference(Seq("a")))))) + } + + test(s"save: new table creations with bucketing for table - mode: $mode") { + val format = classOf[InMemoryV1Provider].getName + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + df.write.mode(mode).option("name", "t1").format(format).bucketBy(2, "a").save() + + checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df) + assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable) + assert(InMemoryV1Provider.tables("t1").partitioning.sameElements( + Array(BucketTransform(LiteralValue(2, IntegerType), Seq(FieldReference(Seq("a"))))))) + } + } + + test("save: default mode is ErrorIfExists") { + val format = classOf[InMemoryV1Provider].getName + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + + df.write.option("name", "t1").format(format).partitionBy("a").save() + // default is ErrorIfExists, and since a table already exists we throw an exception + val e = intercept[AnalysisException] { + df.write.option("name", "t1").format(format).partitionBy("a").save() + } + assert(e.getMessage.contains("already exists")) + } + + test("save: Ignore mode") { + val format = classOf[InMemoryV1Provider].getName + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + + df.write.option("name", "t1").format(format).partitionBy("a").save() + // no-op + df.write.option("name", "t1").format(format).mode("ignore").partitionBy("a").save() + + 
checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df) + } + + test("save: tables can perform schema and partitioning checks if they already exist") { + val format = classOf[InMemoryV1Provider].getName + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + + df.write.option("name", "t1").format(format).partitionBy("a").save() + val e2 = intercept[IllegalArgumentException] { + df.write.mode("append").option("name", "t1").format(format).partitionBy("b").save() + } + assert(e2.getMessage.contains("partitioning")) + + val e3 = intercept[IllegalArgumentException] { + Seq((1, "x")).toDF("c", "d").write.mode("append").option("name", "t1").format(format) + .save() + } + assert(e3.getMessage.contains("schema")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 7cd6ba21b56ec..83f76e7433d42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -25,8 +25,8 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode, SparkSession} -import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider} -import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.catalog._ +import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.connector.write.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} import org.apache.spark.sql.sources.{DataSourceRegister, Filter, InsertableRelation} import org.apache.spark.sql.test.SharedSparkSession @@ -114,19 +114,45 @@ private object InMemoryV1Provider { } } -class InMemoryV1Provider extends TableProvider with DataSourceRegister { +class InMemoryV1Provider extends TableProvider with DataSourceRegister with SupportsCreateTable { override def getTable(options: CaseInsensitiveStringMap): Table = { - InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), { + InMemoryV1Provider.tables.getOrElse(options.get("name"), { new InMemoryTableWithV1Fallback( - "InMemoryTableWithV1Fallback", - new StructType().add("a", IntegerType).add("b", StringType), - Array(IdentityTransform(FieldReference(Seq("a")))), + "EmptyInMemoryTableWithV1Fallback", + new StructType(), + Array.empty, options.asCaseSensitiveMap() ) }) } override def shortName(): String = "in-memory" + + override def canCreateTable(options: CaseInsensitiveStringMap): Boolean = { + !InMemoryV1Provider.tables.contains(options.get("name")) + } + + override def buildTable( + options: CaseInsensitiveStringMap, + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + val t = InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), { + new InMemoryTableWithV1Fallback( + "InMemoryTableWithV1Fallback", + schema, + partitioning, + properties + ) + }) + if (t.schema != schema) { + throw new IllegalArgumentException("Wrong schema provided") + } + if (!t.partitioning.sameElements(partitioning)) { + throw new IllegalArgumentException("Wrong partitioning provided") + } + t + } } class InMemoryTableWithV1Fallback( @@ -136,8 +162,8 @@ class InMemoryTableWithV1Fallback( override val properties: util.Map[String, String]) 
extends Table with SupportsWrite { partitioning.foreach { t => - if (!t.isInstanceOf[IdentityTransform]) { - throw new IllegalArgumentException(s"Transform $t must be IdentityTransform") + if (!t.isInstanceOf[IdentityTransform] && !t.isInstanceOf[BucketTransform]) { + throw new IllegalArgumentException(s"Transform $t must be IdentityTransform or BucketTransform") } }
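
Illustrative usage sketch (not part of the patch): with a TableProvider that also mixes in SupportsCreateTable, such as the InMemoryV1Provider test fixture above, DataFrameWriter.save() can create the table, including its partitioning, before writing. The option key "name", the format class, and the sample data mirror the tests in DataSourceV2DataFrameSuite; an active SparkSession named `spark` is assumed.

  // Assumes an active SparkSession named `spark` with the test fixture on the classpath.
  import org.apache.spark.sql.SaveMode
  import spark.implicits._

  val format = classOf[InMemoryV1Provider].getName
  val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

  // canCreateTable() returns true for an unseen "name", so buildTable() creates the table;
  // partitionBy("a") reaches the source as an IdentityTransform via getV2Transforms().
  df.write.mode(SaveMode.ErrorIfExists).option("name", "t1").format(format).partitionBy("a").save()

  // A second save in Ignore mode is a no-op because canCreateTable() now returns false.
  df.write.mode(SaveMode.Ignore).option("name", "t1").format(format).partitionBy("a").save()

  // Appending with a different partitioning is rejected by the source's own checks in
  // buildTable() (IllegalArgumentException), matching the tests above:
  // df.write.mode(SaveMode.Append).option("name", "t1").format(format).partitionBy("b").save()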