@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

import java.util.Map;

/**
* A mix-in interface for {@link TableProvider} implementations that can create new tables for
* the given data source options. These tables are not stored in any catalog, but the data
* source provides a mechanism to check whether a table can be created for the specified
* options.
*/
@Evolving
public interface SupportsCreateTable extends TableProvider {

/**
* Check whether a new table can be created for the given options.
*
* @param options The options that should be sufficient to define and access a table
* @return true if a new table can be created for the given options, false if a table already exists for these options
*/
boolean canCreateTable(CaseInsensitiveStringMap options);

/**
* Create a table with the given options. It is the data source's responsibility to check if
* the provided schema and the transformations are acceptable in case a table already exists
* for the given options.
*
* @param options The data source options that define how to access the table. This can contain
* the path for file-based tables, the broker addresses for a Kafka source, or
* the JDBC URL for a JDBC data source.
* @param schema The schema of the new table, as a struct type
* @param partitions Transforms to use for partitioning data in the table
* @param properties A string map of table properties
* @return Metadata for the new table. The table creation can be followed up by a write
* @throws IllegalArgumentException If a table already exists for these options with a
* non-conforming schema or different partitioning specification.
* @throws UnsupportedOperationException If a requested partition transform is not supported or
* table properties are not supported
*/
Table buildTable(
CaseInsensitiveStringMap options,
StructType schema,
Transform[] partitions,
Map<String, String> properties);
}
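
For context, here is a minimal sketch of what a provider mixing in this interface could look like, written in Scala against a process-local map. The DemoProvider and DemoTable names and the "name" option key are illustrative assumptions, not part of this change (the actual test implementation, InMemoryV1Provider, appears further down); a real provider would also implement SupportsWrite so the table returned from buildTable can accept the write that follows.

import java.util

import scala.collection.mutable

import org.apache.spark.sql.connector.catalog.{SupportsCreateTable, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Minimal table: just enough metadata to be handed back from buildTable.
class DemoTable(
    override val schema: StructType,
    override val partitioning: Array[Transform],
    override val properties: util.Map[String, String]) extends Table {
  override def name(): String = "demo"
  override def capabilities(): util.Set[TableCapability] = util.Collections.emptySet[TableCapability]
}

// Hypothetical provider: tables live in a process-local map keyed by the "name" option.
class DemoProvider extends TableProvider with SupportsCreateTable {
  private val tables = mutable.Map.empty[String, DemoTable]

  override def getTable(options: CaseInsensitiveStringMap): Table =
    tables(options.get("name"))

  // A table can only be created if nothing is registered under that name yet.
  override def canCreateTable(options: CaseInsensitiveStringMap): Boolean =
    !tables.contains(options.get("name"))

  override def buildTable(
      options: CaseInsensitiveStringMap,
      schema: StructType,
      partitions: Array[Transform],
      properties: util.Map[String, String]): Table = {
    val table = tables.getOrElseUpdate(
      options.get("name"), new DemoTable(schema, partitions, properties))
    // Per the contract above, validate against an existing table instead of silently reusing it.
    if (table.schema != schema) {
      throw new IllegalArgumentException("Schema does not match the existing table")
    }
    if (!table.partitioning.sameElements(partitions)) {
      throw new IllegalArgumentException("Partitioning does not match the existing table")
    }
    table
  }
}
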
113 changes: 77 additions & 36 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect}
import org.apache.spark.sql.catalyst.plans.logical.sql.InsertIntoStatement
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, SupportsWrite, TableCatalog, TableProvider, V1Table}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, SupportsCreateTable, SupportsWrite, TableCatalog, TableProvider, V1Table}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform}
import org.apache.spark.sql.execution.SQLExecution
@@ -249,49 +249,80 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
"write files of Hive data source directly.")
}

assertNotBucketed("save")

val maybeV2Provider = lookupV2Provider()
if (maybeV2Provider.isDefined) {
if (partitioningColumns.nonEmpty) {
throw new AnalysisException(
"Cannot write data to TableProvider implementation if partition columns are specified.")
}

val provider = maybeV2Provider.get
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
provider, df.sparkSession.sessionState.conf)
val options = sessionOptions ++ extraOptions
val dsOptions = new CaseInsensitiveStringMap(options.asJava)

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
provider.getTable(dsOptions) match {
case table: SupportsWrite if table.supports(BATCH_WRITE) =>
lazy val relation = DataSourceV2Relation.create(table, dsOptions)
modeForDSV2 match {
case SaveMode.Append =>
runCommand(df.sparkSession, "save") {
AppendData.byName(relation, df.logicalPlan)
provider match {
case supportsCreate: SupportsCreateTable =>
val canCreate = supportsCreate.canCreateTable(dsOptions)
if (modeForDSV1 == SaveMode.ErrorIfExists && !canCreate) {
throw new AnalysisException(s"Table already exists.")
} else if (modeForDSV1 == SaveMode.Ignore && !canCreate) {
// do nothing
return
}
supportsCreate.buildTable(
dsOptions,
df.schema.asNullable,
getV2Transforms(),
Map.empty[String, String].asJava) match {
case table: SupportsWrite if table.supports(BATCH_WRITE) =>
lazy val relation = DataSourceV2Relation.create(table, dsOptions)
modeForDSV1 match {
case SaveMode.Append | SaveMode.ErrorIfExists | SaveMode.Ignore =>
runCommand(df.sparkSession, "save") {
AppendData.byName(relation, df.logicalPlan)
}

case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
// truncate the table
runCommand(df.sparkSession, "save") {
OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true))
}
}
}

case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
// truncate the table
runCommand(df.sparkSession, "save") {
OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true))
case _: TableProvider =>
if (partitioningColumns.nonEmpty) {
throw new AnalysisException("Cannot write data to TableProvider implementation " +
"if partition columns are specified.")
}
assertNotBucketed("save")

provider.getTable(dsOptions) match {
case table: SupportsWrite if table.supports(BATCH_WRITE) =>
lazy val relation = DataSourceV2Relation.create(table, dsOptions)
modeForDSV2 match {
case SaveMode.Append =>
runCommand(df.sparkSession, "save") {
AppendData.byName(relation, df.logicalPlan)
}

case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
// truncate the table
runCommand(df.sparkSession, "save") {
OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true))
}

case other =>
throw new AnalysisException(s"TableProvider implementation $source cannot be " +
s"written with $other mode, please use Append or Overwrite " +
"modes instead.")
}

case other =>
throw new AnalysisException(s"TableProvider implementation $source cannot be " +
s"written with $other mode, please use Append or Overwrite " +
"modes instead.")
// Streaming also uses the data source V2 API. So it may be that the data source
// implements v2, but has no v2 implementation for batch writes. In that case, we fall
// back to saving as though it's a V1 source.
case _ => saveToV1Source()
}

// Streaming also uses the data source V2 API. So it may be that the data source implements
// v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
// as though it's a V1 source.
case _ => saveToV1Source()
}
} else {
assertNotBucketed("save")
saveToV1Source()
}
}
@@ -508,13 +539,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {


private def saveAsTable(catalog: TableCatalog, ident: Identifier, mode: SaveMode): Unit = {
val partitioning = partitioningColumns.map { colNames =>
colNames.map(name => IdentityTransform(FieldReference(name)))
}.getOrElse(Seq.empty[Transform])
val bucketing = bucketColumnNames.map { cols =>
Seq(BucketTransform(LiteralValue(numBuckets.get, IntegerType), cols.map(FieldReference(_))))
}.getOrElse(Seq.empty[Transform])
val partitionTransforms = partitioning ++ bucketing
val partitionTransforms = getV2Transforms()

val tableOpt = try Option(catalog.loadTable(ident)) catch {
case _: NoSuchTableException => None
Expand Down Expand Up @@ -627,6 +652,22 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
CreateTable(tableDesc, modeForDSV1, Some(df.logicalPlan)))
}

private def getV2Transforms(): Array[Transform] = {
val partitioning = partitioningColumns.map { colNames =>
colNames.map(name => IdentityTransform(FieldReference(name)))
}.getOrElse(Seq.empty[Transform])
val bucketing = getBucketSpec.map { spec =>
if (spec.sortColumnNames.nonEmpty) {
throw new UnsupportedOperationException("V2 tables don't support bucketing with sorting.")
}
val cols = spec.bucketColumnNames
Seq(BucketTransform(LiteralValue(numBuckets.get, IntegerType), cols.map(FieldReference(_))))
}.getOrElse(Seq.empty[Transform])

(partitioning ++ bucketing).toArray
}

/**
* Saves the content of the `DataFrame` to an external database table via JDBC. In the case the
* table already exists in the external database, behavior of this function depends on the
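To make the new save() path above concrete, the following usage sketch walks through the SaveMode behavior a caller would see against a provider that mixes in SupportsCreateTable, assuming the table it builds also supports batch writes. The format class name "com.example.DemoProvider" and the "name" option are placeholders; the exceptions mentioned are the ones thrown by the branch above.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
import spark.implicits._

val df = Seq((1, "x"), (2, "y")).toDF("a", "b")

// First save (default mode is ErrorIfExists): canCreateTable returns true, so buildTable is
// called with the DataFrame's nullable schema and an identity transform on column "a" derived
// from partitionBy via getV2Transforms, and the rows are appended.
df.write.format("com.example.DemoProvider").option("name", "t1").partitionBy("a").save()

// Second save with the default mode: canCreateTable is now false, so an AnalysisException
// ("Table already exists.") is thrown before any data is written.
// df.write.format("com.example.DemoProvider").option("name", "t1").save()

// Ignore mode: canCreateTable is false, so the call returns without doing anything.
df.write.format("com.example.DemoProvider").option("name", "t1").mode("ignore").save()

// Append and Overwrite go through buildTable as well, which is where an existing table is
// expected to validate the incoming schema and partitioning (IllegalArgumentException otherwise).
df.write.format("com.example.DemoProvider").option("name", "t1").mode("append").partitionBy("a").save()

// bucketBy(n, cols) is mapped to a BucketTransform by getV2Transforms; plain TableProvider
// implementations (without SupportsCreateTable) still reject bucketing via assertNotBucketed.
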
@@ -17,7 +17,9 @@

package org.apache.spark.sql.connector

import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue}
import org.apache.spark.sql.types.IntegerType

class DataSourceV2DataFrameSuite
extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = false) {
@@ -26,11 +28,13 @@ class DataSourceV2DataFrameSuite
before {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName)
InMemoryV1Provider.clear()
}

after {
spark.sessionState.catalogManager.reset()
spark.sessionState.conf.clear()
InMemoryV1Provider.clear()
}

override protected val catalogAndNamespace: String = "testcat.ns1.ns2.tbls"
@@ -122,4 +126,68 @@ class DataSourceV2DataFrameSuite
checkAnswer(spark.table(t1), Seq(Row("c", "d")))
}
}

SaveMode.values().foreach { mode =>
test(s"save: new table creations with partitioning for table - mode: $mode") {
val format = classOf[InMemoryV1Provider].getName
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
df.write.mode(mode).option("name", "t1").format(format).partitionBy("a").save()

checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable)
assert(InMemoryV1Provider.tables("t1").partitioning.sameElements(
Array(IdentityTransform(FieldReference(Seq("a"))))))
}

test(s"save: new table creations with bucketing for table - mode: $mode") {
val format = classOf[InMemoryV1Provider].getName
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
df.write.mode(mode).option("name", "t1").format(format).bucketBy(2, "a").save()

checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable)
assert(InMemoryV1Provider.tables("t1").partitioning.sameElements(
Array(BucketTransform(LiteralValue(2, IntegerType), Seq(FieldReference(Seq("a")))))))
}
}

test("save: default mode is ErrorIfExists") {
val format = classOf[InMemoryV1Provider].getName
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

df.write.option("name", "t1").format(format).partitionBy("a").save()
// default is ErrorIfExists, and since a table already exists we throw an exception
val e = intercept[AnalysisException] {
df.write.option("name", "t1").format(format).partitionBy("a").save()
}
assert(e.getMessage.contains("already exists"))
}

test("save: Ignore mode") {
val format = classOf[InMemoryV1Provider].getName
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

df.write.option("name", "t1").format(format).partitionBy("a").save()
// no-op
df.write.option("name", "t1").format(format).mode("ignore").partitionBy("a").save()

checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
}

test("save: tables can perform schema and partitioning checks if they already exist") {
val format = classOf[InMemoryV1Provider].getName
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

df.write.option("name", "t1").format(format).partitionBy("a").save()
val e2 = intercept[IllegalArgumentException] {
df.write.mode("append").option("name", "t1").format(format).partitionBy("b").save()
}
assert(e2.getMessage.contains("partitioning"))

val e3 = intercept[IllegalArgumentException] {
Seq((1, "x")).toDF("c", "d").write.mode("append").option("name", "t1").format(format)
.save()
}
assert(e3.getMessage.contains("schema"))
}
}
@@ -25,8 +25,8 @@ import scala.collection.mutable
import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode, SparkSession}
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.connector.write.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter, InsertableRelation}
import org.apache.spark.sql.test.SharedSparkSession
@@ -114,19 +114,45 @@ private object InMemoryV1Provider {
}
}

class InMemoryV1Provider extends TableProvider with DataSourceRegister {
class InMemoryV1Provider extends TableProvider with DataSourceRegister with SupportsCreateTable {
override def getTable(options: CaseInsensitiveStringMap): Table = {
InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), {
InMemoryV1Provider.tables.getOrElse(options.get("name"), {
new InMemoryTableWithV1Fallback(
"InMemoryTableWithV1Fallback",
new StructType().add("a", IntegerType).add("b", StringType),
Array(IdentityTransform(FieldReference(Seq("a")))),
"EmptyInMemoryTableWithV1Fallback",
new StructType(),
Array.empty,
options.asCaseSensitiveMap()
)
})
}

override def shortName(): String = "in-memory"

override def canCreateTable(options: CaseInsensitiveStringMap): Boolean = {
!InMemoryV1Provider.tables.contains(options.get("name"))
}

override def buildTable(
options: CaseInsensitiveStringMap,
schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String]): Table = {
val t = InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), {
new InMemoryTableWithV1Fallback(
"InMemoryTableWithV1Fallback",
schema,
partitioning,
properties
)
})
if (t.schema != schema) {
throw new IllegalArgumentException("Wrong schema provided")
}
if (!t.partitioning.sameElements(partitioning)) {
throw new IllegalArgumentException("Wrong partitioning provided")
}
t
}
}

class InMemoryTableWithV1Fallback(
@@ -136,8 +162,8 @@ class InMemoryTableWithV1Fallback(
override val properties: util.Map[String, String]) extends Table with SupportsWrite {

partitioning.foreach { t =>
if (!t.isInstanceOf[IdentityTransform]) {
throw new IllegalArgumentException(s"Transform $t must be IdentityTransform")
if (!t.isInstanceOf[IdentityTransform] && !t.isInstanceOf[BucketTransform]) {
throw new IllegalArgumentException(s"Transform $t must be IdentityTransform or BucketTransform")
}
}
