From cad1ab78026d7139548149dbe2b4c36dd865d0aa Mon Sep 17 00:00:00 2001
From: DB Tsai
Date: Thu, 12 Dec 2019 17:37:39 -0800
Subject: [PATCH 1/5] DSV2

---
 .../spark/sql/connector/InMemoryTable.scala  | 37 +++++++-
 .../spark/sql/DataFrameWriterV2Suite.scala   | 86 ++++++++++++++++++-
 2 files changed, 118 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index c9e4e0aad5704..80ea012a47286 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -26,7 +26,7 @@ import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{IdentityTransform, Transform}
+import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
@@ -59,10 +59,39 @@ class InMemoryTable(
 
   def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq
 
-  private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames)
-  private val partIndexes = partFieldNames.map(schema.fieldIndex)
+  private val partRefs: Array[NamedReference] = partitioning.flatMap(_.references)
 
-  private def getKey(row: InternalRow): Seq[Any] = partIndexes.map(row.toSeq(schema)(_))
+  private val partFieldNames: Seq[String] = partRefs.map { ref =>
+    schema.findNestedField(ref.fieldNames(), includeCollections = false) match {
+      case Some(_) => ref.describe()
+      case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.")
+    }
+  }
+
+  private val partCols: Array[Array[String]] = partRefs.map { ref =>
+    schema.findNestedField(ref.fieldNames(), includeCollections = false) match {
+      case Some(_) => ref.fieldNames()
+      case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.")
+    }
+  }
+
+  private def getKey(row: InternalRow): Seq[Any] = {
+    def extractor(fieldNames: Array[String], schema: StructType, values: Seq[Any]): Any = {
+      val index = schema.fieldIndex(fieldNames(0))
+      val value = values(index)
+      if (fieldNames.length > 1) {
+        (value, schema(index).dataType) match {
+          case (value: InternalRow, nestedSchema: StructType) =>
+            extractor(fieldNames.slice(1, fieldNames.length),
+              nestedSchema, value.toSeq(nestedSchema))
+          case _ => throw new IllegalArgumentException(s"does not exist.")
+        }
+      } else {
+        value
+      }
+    }
+    partCols.map(filedNames => extractor(filedNames, schema, row.toSeq(schema)))
+  }
 
   def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
     data.foreach(_.rows.foreach { row =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index d49dc58e93ddb..9f8b1011dc1a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -17,20 +17,24 @@
 
 package org.apache.spark.sql
 
+import java.sql.Timestamp
+
 import scala.collection.JavaConverters._
 
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
-import org.apache.spark.sql.connector.InMemoryTableCatalog
+import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog}
 import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.types.TimestampType
 import org.apache.spark.sql.util.QueryExecutionListener
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter {
@@ -550,4 +554,84 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     assert(replaced.partitioning.isEmpty)
     assert(replaced.properties === defaultOwnership.asJava)
   }
+
+  test("SPARK-30289 Create: partitioned by nested column") {
+    val schema = new StructType().add("ts", new StructType()
+      .add("created", TimestampType)
+      .add("modified", TimestampType)
+      .add("timezone", StringType))
+
+    val data = Seq(
+      Row(Row(Timestamp.valueOf("2019-06-01 10:00:00"), Timestamp.valueOf("2019-09-02 07:00:00"),
+        "America/Los_Angeles")),
+      Row(Row(Timestamp.valueOf("2019-08-26 18:00:00"), Timestamp.valueOf("2019-09-26 18:00:00"),
+        "America/Los_Angeles")),
+      Row(Row(Timestamp.valueOf("2018-11-23 18:00:00"), Timestamp.valueOf("2018-12-22 18:00:00"),
+        "America/New_York")))
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema)
+
+    df.writeTo("testcat.table_name")
+      .partitionedBy($"ts.timezone")
+      .create()
+
+    val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+      .asInstanceOf[InMemoryTable]
+
+    assert(table.name === "testcat.table_name")
+    assert(table.partitioning === Seq(IdentityTransform(FieldReference(Array("ts", "timezone")))))
+    checkAnswer(spark.table(table.name), data)
+    assert(table.dataMap.toArray.length == 2)
+    assert(table.dataMap(Seq(UTF8String.fromString("America/Los_Angeles"))).rows.size == 2)
+    assert(table.dataMap(Seq(UTF8String.fromString("America/New_York"))).rows.size == 1)
+
+    // TODO: `DataSourceV2Strategy` can not translate nested fields into source filter yet
+    // so the following sql will fail.
+    // sql("DELETE FROM testcat.table_name WHERE ts.timezone = \"America/Los_Angeles\"")
+  }
+
+  test("SPARK-30289 Create: partitioned by multiple transforms on nested columns") {
+    spark.table("source")
+      .withColumn("ts", struct(
+        lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created",
+        lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
+        lit("America/Los_Angeles") as "timezone"))
+      .writeTo("testcat.table_name")
+      .tableProperty("allow-unsupported-transforms", "true")
+      .partitionedBy(
+        years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"),
+        years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified")
+      )
+      .create()
+
+    val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+    assert(table.name === "testcat.table_name")
+    assert(table.partitioning === Seq(
+      YearsTransform(FieldReference(Array("ts", "created"))),
+      MonthsTransform(FieldReference(Array("ts", "created"))),
+      DaysTransform(FieldReference(Array("ts", "created"))),
+      HoursTransform(FieldReference(Array("ts", "created"))),
+      YearsTransform(FieldReference(Array("ts", "modified"))),
+      MonthsTransform(FieldReference(Array("ts", "modified"))),
+      DaysTransform(FieldReference(Array("ts", "modified"))),
+      HoursTransform(FieldReference(Array("ts", "modified")))))
+  }
+
+  test("SPARK-30289 Create a: partitioned by bucket(4, ts.timezone)") {
+    spark.table("source")
+      .withColumn("ts", struct(
+        lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created",
+        lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
+        lit("America/Los_Angeles") as "timezone"))
+      .writeTo("testcat.table_name")
+      .tableProperty("allow-unsupported-transforms", "true")
+      .partitionedBy(bucket(4, $"ts.timezone"))
+      .create()
+
+    val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+    assert(table.name === "testcat.table_name")
+    assert(table.partitioning === Seq(BucketTransform(LiteralValue(4, IntegerType),
+      Seq(FieldReference(Seq("ts", "timezone"))))))
+  }
 }

From 259d9d344c938a789e50ca3f925c1a89c20417f5 Mon Sep 17 00:00:00 2001
From: DB Tsai
Date: Thu, 13 Feb 2020 15:01:06 -0800
Subject: [PATCH 2/5] refacting

---
 .../spark/sql/connector/InMemoryTable.scala  | 32 ++++++++-----------
 1 file changed, 13 insertions(+), 19 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 80ea012a47286..d9c553c864fb8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -59,16 +59,7 @@ class InMemoryTable(
 
   def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq
 
-  private val partRefs: Array[NamedReference] = partitioning.flatMap(_.references)
-
-  private val partFieldNames: Seq[String] = partRefs.map { ref =>
-    schema.findNestedField(ref.fieldNames(), includeCollections = false) match {
-      case Some(_) => ref.describe()
-      case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.")
-    }
-  }
-
-  private val partCols: Array[Array[String]] = partRefs.map { ref =>
+  private val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref =>
     schema.findNestedField(ref.fieldNames(), includeCollections = false) match {
       case Some(_) => ref.fieldNames()
       case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.")
@@ -76,21 +67,21 @@ class InMemoryTable(
   }
 
   private def getKey(row: InternalRow): Seq[Any] = {
-    def extractor(fieldNames: Array[String], schema: StructType, values: Seq[Any]): Any = {
+    def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = {
       val index = schema.fieldIndex(fieldNames(0))
-      val value = values(index)
+      val value = row.toSeq(schema).apply(index)
       if (fieldNames.length > 1) {
         (value, schema(index).dataType) match {
-          case (value: InternalRow, nestedSchema: StructType) =>
-            extractor(fieldNames.slice(1, fieldNames.length),
-              nestedSchema, value.toSeq(nestedSchema))
-          case _ => throw new IllegalArgumentException(s"does not exist.")
+          case (row: InternalRow, nestedSchema: StructType) =>
+            extractor(fieldNames.slice(1, fieldNames.length), nestedSchema, row)
+          case (_, dataType) =>
+            throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}")
         }
       } else {
         value
       }
     }
-    partCols.map(filedNames => extractor(filedNames, schema, row.toSeq(schema)))
+    partCols.map(filedNames => extractor(filedNames, schema, row))
   }
 
   def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
@@ -175,8 +166,10 @@ class InMemoryTable(
   }
 
   private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
-      val deleteKeys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+      val deleteKeys = InMemoryTable.filtersToKeys(
+        dataMap.keys, partCols.map(_.toSeq.quoted), filters)
       dataMap --= deleteKeys
       withData(messages.map(_.asInstanceOf[BufferedRows]))
     }
@@ -190,7 +183,8 @@ class InMemoryTable(
   }
 
   override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
-    dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+    dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
   }
 }
 

From ef4e21a5f68b76eff84f6cf8f8aa581f1e081343 Mon Sep 17 00:00:00 2001
From: DB Tsai
Date: Fri, 14 Feb 2020 00:38:51 -0800
Subject: [PATCH 3/5] fix typo

---
 .../scala/org/apache/spark/sql/connector/InMemoryTable.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index d9c553c864fb8..5564e907572c5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -81,7 +81,7 @@ class InMemoryTable(
         value
       }
     }
-    partCols.map(filedNames => extractor(filedNames, schema, row))
+    partCols.map(fieldNames => extractor(fieldNames, schema, row))
   }
 
   def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {

From 21ebd26d01ba126ddbcbc45d3268d5b922ff9ac1 Mon Sep 17 00:00:00 2001
From: DB Tsai
Date: Fri, 14 Feb 2020 00:46:55 -0800
Subject: [PATCH 4/5] typo

---
 .../scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index 9f8b1011dc1a6..cd157086a8b8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -617,7 +617,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
       HoursTransform(FieldReference(Array("ts", "modified")))))
   }
 
-  test("SPARK-30289 Create a: partitioned by bucket(4, ts.timezone)") {
+  test("SPARK-30289 Create: partitioned by bucket(4, ts.timezone)") {
     spark.table("source")
       .withColumn("ts", struct(
         lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created",

From e2cd87fc79c5df60d95a3848f57ba7eca9401cdb Mon Sep 17 00:00:00 2001
From: DB Tsai
Date: Fri, 14 Feb 2020 01:02:20 -0800
Subject: [PATCH 5/5] syntax sugar

---
 .../scala/org/apache/spark/sql/connector/InMemoryTable.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 5564e907572c5..0187ae31e2d1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -73,7 +73,7 @@ class InMemoryTable(
       if (fieldNames.length > 1) {
         (value, schema(index).dataType) match {
           case (row: InternalRow, nestedSchema: StructType) =>
-            extractor(fieldNames.slice(1, fieldNames.length), nestedSchema, row)
+            extractor(fieldNames.drop(1), nestedSchema, row)
           case (_, dataType) =>
             throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}")
         }