From 73bd0e671d080000b907dba85a72b8a6d5eee834 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 15 Dec 2021 15:46:16 -0800 Subject: [PATCH 1/9] [SPARK-37627][SQL][FOLLOWUP] Add test for sorted BucketTransform --- .../expressions/TransformExtractorSuite.scala | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index b2371ce667ffc..a0d3793135892 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst +import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket import org.apache.spark.sql.types.DataType class TransformExtractorSuite extends SparkFunSuite { @@ -131,14 +132,22 @@ class TransformExtractorSuite extends SparkFunSuite { test("Bucket extractor") { val col = ref("a", "b") - val bucketTransform = new Transform { + val sortedCol = ref("c", "d") + val bucketTransform1 = new Transform { override def name: String = "bucket" override def references: Array[NamedReference] = Array(col) override def arguments: Array[Expression] = Array(lit(16), col) override def toString: String = s"bucket(16, ${col.describe})" } - bucketTransform match { + val sortedBucketTransform1 = new Transform { + override def name: String = "bucket" + override def references: Array[NamedReference] = Array(col) ++ Array(sortedCol) + override def arguments: Array[Expression] = Array(lit(16), col, sortedCol) + override def describe: String = s"bucket(16, ${col.describe} ${sortedCol.describe})" + } + + bucketTransform1 match { case BucketTransform(numBuckets, FieldReference(seq), _) => assert(numBuckets === 16) assert(seq === Seq("a", "b")) @@ -152,5 +161,33 @@ class TransformExtractorSuite extends SparkFunSuite { case _ => // expected } + + sortedBucketTransform1 match { + case BucketTransform(numBuckets, FieldReference(seq), FieldReference(sorted)) => + assert(numBuckets === 16) + assert(seq === Seq("a", "b")) + assert(sorted === Seq("c", "d")) + case _ => + fail("Did not match BucketTransform extractor") + } + + val bucketTransform2 = bucket(16, Array(col)) + val reference1 = bucketTransform2.references + assert(reference1.length == 1 && reference1(0).fieldNames() === Seq("a", "b")) + val arguments1 = bucketTransform2.arguments + assert(arguments1.length == 2) + assert(arguments1(0).asInstanceOf[LiteralValue[Integer]].value === 16) + assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a", "b")) + + val sortedBucketTransform2 = bucket(16, Array(col), Array(sortedCol)) + val reference2 = sortedBucketTransform2.references + assert(reference2.length == 2) + assert(reference2(0).fieldNames() === Seq("a", "b")) + assert(reference2(1).fieldNames() === Seq("c", "d")) + val arguments2 = sortedBucketTransform2.arguments + assert(arguments2.length == 3) + assert(arguments2(0).asInstanceOf[LiteralValue[Integer]].value === 16) + assert(arguments2(1).asInstanceOf[NamedReference].fieldNames() === Seq("a", "b")) + assert(arguments2(2).asInstanceOf[NamedReference].fieldNames() === Seq("c", "d")) } } From 4db60f789cd9c78cfe4ddac1cbff3951b34e3ad1 Mon Sep 17 
00:00:00 2001 From: Huaxin Gao Date: Sun, 2 Jan 2022 20:37:17 -0800 Subject: [PATCH 2/9] use literal to separate bucket cols and sorted cols --- .../expressions/TransformExtractorSuite.scala | 78 ++++++++++++------- .../sql/connector/DataSourceV2SQLSuite.scala | 14 ++-- 2 files changed, 59 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index a0d3793135892..35cded3e6d10e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -40,6 +40,7 @@ class TransformExtractorSuite extends SparkFunSuite { override def toString: String = names.mkString(".") } + /** * Creates a Transform using an anonymous class. */ @@ -132,22 +133,14 @@ class TransformExtractorSuite extends SparkFunSuite { test("Bucket extractor") { val col = ref("a", "b") - val sortedCol = ref("c", "d") - val bucketTransform1 = new Transform { + val bucketTransform = new Transform { override def name: String = "bucket" override def references: Array[NamedReference] = Array(col) override def arguments: Array[Expression] = Array(lit(16), col) override def toString: String = s"bucket(16, ${col.describe})" } - val sortedBucketTransform1 = new Transform { - override def name: String = "bucket" - override def references: Array[NamedReference] = Array(col) ++ Array(sortedCol) - override def arguments: Array[Expression] = Array(lit(16), col, sortedCol) - override def describe: String = s"bucket(16, ${col.describe} ${sortedCol.describe})" - } - - bucketTransform1 match { + bucketTransform match { case BucketTransform(numBuckets, FieldReference(seq), _) => assert(numBuckets === 16) assert(seq === Seq("a", "b")) @@ -161,8 +154,21 @@ class TransformExtractorSuite extends SparkFunSuite { case _ => // expected } + } - sortedBucketTransform1 match { + test("Sorted Bucket extractor") { + val col = Array(ref("a"), ref("b")) + val sortedCol = Array(ref("c"), ref("d")) + + val sortedBucketTransform = new Transform { + override def name: String = "sortedBucket" + override def references: Array[NamedReference] = col ++ sortedCol + override def arguments: Array[Expression] = (col :+ lit(16)) ++ sortedCol + override def describe: String = s"bucket(16, ${col(0).describe}, ${col(1).describe} " + + s"${sortedCol(0).describe} ${sortedCol(1).describe})" + } + + sortedBucketTransform match { case BucketTransform(numBuckets, FieldReference(seq), FieldReference(sorted)) => assert(numBuckets === 16) assert(seq === Seq("a", "b")) @@ -170,24 +176,40 @@ class TransformExtractorSuite extends SparkFunSuite { case _ => fail("Did not match BucketTransform extractor") } + } - val bucketTransform2 = bucket(16, Array(col)) - val reference1 = bucketTransform2.references - assert(reference1.length == 1 && reference1(0).fieldNames() === Seq("a", "b")) - val arguments1 = bucketTransform2.arguments - assert(arguments1.length == 2) + test("test bucket") { + val col = Array(ref("a"), ref("b")) + val sortedCol = Array(ref("c"), ref("d")) + + val bucketTransform = bucket(16, col) + val reference1 = bucketTransform.references + assert(reference1.length == 2) + assert(reference1(0).fieldNames() === Seq("a")) + assert(reference1(1).fieldNames() === Seq("b")) + val arguments1 = bucketTransform.arguments + 
assert(arguments1.length == 3) assert(arguments1(0).asInstanceOf[LiteralValue[Integer]].value === 16) - assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a", "b")) - - val sortedBucketTransform2 = bucket(16, Array(col), Array(sortedCol)) - val reference2 = sortedBucketTransform2.references - assert(reference2.length == 2) - assert(reference2(0).fieldNames() === Seq("a", "b")) - assert(reference2(1).fieldNames() === Seq("c", "d")) - val arguments2 = sortedBucketTransform2.arguments - assert(arguments2.length == 3) - assert(arguments2(0).asInstanceOf[LiteralValue[Integer]].value === 16) - assert(arguments2(1).asInstanceOf[NamedReference].fieldNames() === Seq("a", "b")) - assert(arguments2(2).asInstanceOf[NamedReference].fieldNames() === Seq("c", "d")) + assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a")) + assert(arguments1(2).asInstanceOf[NamedReference].fieldNames() === Seq("b")) + val copied1 = bucketTransform.withReferences(reference1) + assert(copied1.equals(bucketTransform)) + + val sortedBucketTransform = bucket(16, col, sortedCol) + val reference2 = sortedBucketTransform.references + assert(reference2.length == 4) + assert(reference2(0).fieldNames() === Seq("a")) + assert(reference2(1).fieldNames() === Seq("b")) + assert(reference2(2).fieldNames() === Seq("c")) + assert(reference2(3).fieldNames() === Seq("d")) + val arguments2 = sortedBucketTransform.arguments + assert(arguments2.length == 5) + assert(arguments2(0).asInstanceOf[NamedReference].fieldNames() === Seq("a")) + assert(arguments2(1).asInstanceOf[NamedReference].fieldNames() === Seq("b")) + assert(arguments2(2).asInstanceOf[LiteralValue[Integer]].value === 16) + assert(arguments2(3).asInstanceOf[NamedReference].fieldNames() === Seq("c")) + assert(arguments2(4).asInstanceOf[NamedReference].fieldNames() === Seq("d")) + val copied2 = sortedBucketTransform.withReferences(reference2) + assert(copied2.equals(sortedBucketTransform)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 90f9f157b0284..f1a8a2751f714 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1566,18 +1566,22 @@ class DataSourceV2SQLSuite test("create table using - with sorted bucket") { val identifier = "testcat.table_name" withTable(identifier) { - sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" + - s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS") - val table = getTableMetadata(identifier) + sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" + + s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS") val describe = spark.sql(s"DESCRIBE $identifier") + describe.show(false) val part1 = describe .filter("col_name = 'Part 0'") .select("data_type").head.getString(0) - assert(part1 === "c") + assert(part1 === "a") val part2 = describe .filter("col_name = 'Part 1'") .select("data_type").head.getString(0) - assert(part2 === "bucket(4, b, a)") + assert(part2 === "b") + val part3 = describe + .filter("col_name = 'Part 2'") + .select("data_type").head.getString(0) + assert(part3 === "sortedBucket(c, d, 4, e, f)") } } From 6b9f42cf605efad24ef1119a697a2ecbea9e5885 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 3 Jan 2022 17:56:36 -0800 
Subject: [PATCH 3/9] remove extra space and extra blank line --- .../sql/connector/expressions/TransformExtractorSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index 35cded3e6d10e..ec7919dd8959b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -40,7 +40,6 @@ class TransformExtractorSuite extends SparkFunSuite { override def toString: String = names.mkString(".") } - /** * Creates a Transform using an anonymous class. */ From 00a90daceba7af7ba528fcbc4a8eeb78db62fc48 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 4 Jan 2022 16:25:58 -0800 Subject: [PATCH 4/9] resolve conflict --- .../connector/expressions/expressions.scala | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index e3eab6f6730f1..600cc0eef1aa2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -104,24 +104,29 @@ private[sql] final case class BucketTransform( columns: Seq[NamedReference], sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform { - override val name: String = "bucket" + override val name: String = if (sortedColumns.nonEmpty) "sortedBucket" else "bucket" override def references: Array[NamedReference] = { arguments.collect { case named: NamedReference => named } } - override def arguments: Array[Expression] = numBuckets +: columns.toArray - - override def toString: String = + override def arguments: Array[Expression] = { if (sortedColumns.nonEmpty) { - s"bucket(${arguments.map(_.describe).mkString(", ")}," + - s" ${sortedColumns.map(_.describe).mkString(", ")})" + (columns.toArray :+ numBuckets) ++ sortedColumns } else { - s"bucket(${arguments.map(_.describe).mkString(", ")})" + numBuckets +: columns.toArray } + } + + override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})" override def withReferences(newReferences: Seq[NamedReference]): Transform = { - this.copy(columns = newReferences) + if (sortedColumns.isEmpty) { + this.copy(columns = newReferences) + } else { + val splits = newReferences.grouped(columns.length).toList + this.copy(columns = splits(0), sortedColumns = splits(1)) + } } } @@ -140,19 +145,26 @@ private[sql] object BucketTransform { } def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] = - transform match { - case NamedTransform("bucket", Seq( - Lit(value: Int, IntegerType), - Ref(partCols: Seq[String]), - Ref(sortCols: Seq[String]))) => - Some((value, FieldReference(partCols), FieldReference(sortCols))) - case NamedTransform("bucket", Seq( - Lit(value: Int, IntegerType), - Ref(partCols: Seq[String]))) => + transform match { + case NamedTransform("sortedBucket", s) => + var index: Int = -1 + var posOfLit: Int = -1 + var numOfBucket: Int = -1 + s.foreach { + case Lit(value: Int, IntegerType) => + numOfBucket = value + index = index + 1 + posOfLit = index + 
case _ => index = index + 1 + } + val splits = s.splitAt(posOfLit) + Some(numOfBucket, FieldReference( + splits._1.map(_.describe)), FieldReference(splits._2.drop(1).map(_.describe))) + case NamedTransform("bucket", Seq(Lit(value: Int, IntegerType), Ref(partCols: Seq[String]))) => Some((value, FieldReference(partCols), FieldReference(Seq.empty[String]))) case _ => None - } + } } private[sql] final case class ApplyTransform( From 51ead2b5cdc7c89528f515c128b648bcf6cb4d7e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 4 Jan 2022 16:31:03 -0800 Subject: [PATCH 5/9] remove unnessary change --- .../apache/spark/sql/connector/expressions/expressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 600cc0eef1aa2..cced6c4a4e2d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -164,7 +164,7 @@ private[sql] object BucketTransform { Some((value, FieldReference(partCols), FieldReference(Seq.empty[String]))) case _ => None - } + } } private[sql] final case class ApplyTransform( From 3f220d0262d50f73dc72f99bccfd2e1a3ab5902e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 5 Jan 2022 14:41:15 -0800 Subject: [PATCH 6/9] separate BucketTransform and SortedBucketTransform --- .../catalog/CatalogV2Implicits.scala | 6 +- .../connector/expressions/expressions.scala | 83 +++++++++++++------ .../sql/connector/catalog/InMemoryTable.scala | 7 +- .../expressions/TransformExtractorSuite.scala | 10 +-- .../datasources/v2/V2SessionCatalog.scala | 8 +- .../sql/connector/DataSourceV2SQLSuite.scala | 1 - 6 files changed, 77 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 185a1a2644e2f..605ae2c5075f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.quoteIfNeeded -import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform} +import org.apache.spark.sql.connector.expressions.{IdentityTransform, LogicalExpressions, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors /** @@ -37,11 +37,11 @@ private[sql] object CatalogV2Implicits { } implicit class BucketSpecHelper(spec: BucketSpec) { - def asTransform: BucketTransform = { + def asTransform: Transform = { val references = spec.bucketColumnNames.map(col => reference(Seq(col))) if (spec.sortColumnNames.nonEmpty) { val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col))) - bucket(spec.numBuckets, references.toArray, sortedCol.toArray) + sortedBucket(spec.numBuckets, references.toArray, sortedCol.toArray) } else { bucket(spec.numBuckets, references.toArray) } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index cced6c4a4e2d1..238a51412ffda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -45,11 +45,11 @@ private[sql] object LogicalExpressions { def bucket(numBuckets: Int, references: Array[NamedReference]): BucketTransform = BucketTransform(literal(numBuckets, IntegerType), references) - def bucket( + def sortedBucket( numBuckets: Int, references: Array[NamedReference], - sortedCols: Array[NamedReference]): BucketTransform = - BucketTransform(literal(numBuckets, IntegerType), references, sortedCols) + sortedCols: Array[NamedReference]): SortedBucketTransform = + SortedBucketTransform(literal(numBuckets, IntegerType), references, sortedCols) def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference) @@ -100,42 +100,76 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R } private[sql] final case class BucketTransform( + numBuckets: Literal[Int], + columns: Seq[NamedReference]) extends RewritableTransform { + + override val name: String = "bucket" + + override def references: Array[NamedReference] = { + arguments.collect { case named: NamedReference => named } + } + + override def arguments: Array[Expression] = numBuckets +: columns.toArray + + override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})" + + override def toString: String = describe + + override def withReferences(newReferences: Seq[NamedReference]): Transform = { + this.copy(columns = newReferences) + } +} + +private[sql] object BucketTransform { + def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match { + case transform: Transform => + transform match { + case BucketTransform(n, FieldReference(parts)) => + Some((n, FieldReference(parts))) + case _ => + None + } + case _ => + None + } + + def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match { + case NamedTransform("bucket", Seq( + Lit(value: Int, IntegerType), + Ref(seq: Seq[String]))) => + Some((value, FieldReference(seq))) + case _ => + None + } +} + +private[sql] final case class SortedBucketTransform( numBuckets: Literal[Int], columns: Seq[NamedReference], sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform { - override val name: String = if (sortedColumns.nonEmpty) "sortedBucket" else "bucket" + override val name: String = "sortedBucket" override def references: Array[NamedReference] = { arguments.collect { case named: NamedReference => named } } - override def arguments: Array[Expression] = { - if (sortedColumns.nonEmpty) { - (columns.toArray :+ numBuckets) ++ sortedColumns - } else { - numBuckets +: columns.toArray - } - } + override def arguments: Array[Expression] = (columns.toArray :+ numBuckets) ++ sortedColumns override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})" override def withReferences(newReferences: Seq[NamedReference]): Transform = { - if (sortedColumns.isEmpty) { - this.copy(columns = newReferences) - } else { - val splits = newReferences.grouped(columns.length).toList - this.copy(columns = splits(0), sortedColumns = splits(1)) - } + this.copy(columns = newReferences.take(columns.length), + sortedColumns = 
newReferences.drop(columns.length)) } } -private[sql] object BucketTransform { +private[sql] object SortedBucketTransform { def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] = expr match { case transform: Transform => transform match { - case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) => + case SortedBucketTransform(n, FieldReference(parts), FieldReference(sortCols)) => Some((n, FieldReference(parts), FieldReference(sortCols))) case _ => None @@ -146,22 +180,19 @@ private[sql] object BucketTransform { def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] = transform match { - case NamedTransform("sortedBucket", s) => + case NamedTransform("sortedBucket", arguments) => var index: Int = -1 var posOfLit: Int = -1 var numOfBucket: Int = -1 - s.foreach { + arguments.foreach { case Lit(value: Int, IntegerType) => numOfBucket = value index = index + 1 posOfLit = index case _ => index = index + 1 } - val splits = s.splitAt(posOfLit) - Some(numOfBucket, FieldReference( - splits._1.map(_.describe)), FieldReference(splits._2.drop(1).map(_.describe))) - case NamedTransform("bucket", Seq(Lit(value: Int, IntegerType), Ref(partCols: Seq[String]))) => - Some((value, FieldReference(partCols), FieldReference(Seq.empty[String]))) + Some(numOfBucket, FieldReference(arguments.take(posOfLit).map(_.describe)), + FieldReference(arguments.drop(posOfLit + 1).map(_.describe))) case _ => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index fa8be1b8fa3c0..6b6fc4cd87801 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -80,6 +80,7 @@ class InMemoryTable( case _: DaysTransform => case _: HoursTransform => case _: BucketTransform => + case _: SortedBucketTransform => case t if !allowUnsupportedTransforms => throw new IllegalArgumentException(s"Transform $t is not a supported transform") } @@ -161,7 +162,11 @@ class InMemoryTable( case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } - case BucketTransform(numBuckets, ref, _) => + case BucketTransform(numBuckets, ref) => + val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) + val valueHashCode = if (value == null) 0 else value.hashCode + ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets + case SortedBucketTransform(numBuckets, ref, _) => val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) val valueHashCode = if (value == null) 0 else value.hashCode ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index ec7919dd8959b..8c895afc773b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connector.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst -import 
org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket +import org.apache.spark.sql.connector.expressions.LogicalExpressions.{bucket, sortedBucket} import org.apache.spark.sql.types.DataType class TransformExtractorSuite extends SparkFunSuite { @@ -140,7 +140,7 @@ class TransformExtractorSuite extends SparkFunSuite { } bucketTransform match { - case BucketTransform(numBuckets, FieldReference(seq), _) => + case BucketTransform(numBuckets, FieldReference(seq)) => assert(numBuckets === 16) assert(seq === Seq("a", "b")) case _ => @@ -148,7 +148,7 @@ class TransformExtractorSuite extends SparkFunSuite { } transform("unknown", ref("a", "b")) match { - case BucketTransform(_, _, _) => + case BucketTransform(_, _) => fail("Matched unknown transform") case _ => // expected @@ -168,7 +168,7 @@ class TransformExtractorSuite extends SparkFunSuite { } sortedBucketTransform match { - case BucketTransform(numBuckets, FieldReference(seq), FieldReference(sorted)) => + case SortedBucketTransform(numBuckets, FieldReference(seq), FieldReference(sorted)) => assert(numBuckets === 16) assert(seq === Seq("a", "b")) assert(sorted === Seq("c", "d")) @@ -194,7 +194,7 @@ class TransformExtractorSuite extends SparkFunSuite { val copied1 = bucketTransform.withReferences(reference1) assert(copied1.equals(bucketTransform)) - val sortedBucketTransform = bucket(16, col, sortedCol) + val sortedBucketTransform = sortedBucket(16, col, sortedCol) val reference2 = sortedBucketTransform.references assert(reference2.length == 4) assert(reference2(0).fieldNames() === Seq("a")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index d5547c1f3c1e3..ab334549f38ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlread import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.catalog.NamespaceChange.RemoveProperty -import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, SortedBucketTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.types.StructType @@ -318,7 +318,11 @@ private[sql] object V2SessionCatalog { case IdentityTransform(FieldReference(Seq(col))) => identityCols += col - case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) => + case BucketTransform(numBuckets, FieldReference(Seq(col))) => + bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil)) + + case SortedBucketTransform( + numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) => bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil)) case transform => diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index f1a8a2751f714..86dfe7c324870 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1569,7 +1569,6 @@ class DataSourceV2SQLSuite sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" + s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS") val describe = spark.sql(s"DESCRIBE $identifier") - describe.show(false) val part1 = describe .filter("col_name = 'Part 0'") .select("data_type").head.getString(0) From 61dc79590d07061d37192131b1e6aea35d55887a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 6 Jan 2022 17:45:36 -0800 Subject: [PATCH 7/9] address comments --- .../catalog/CatalogV2Implicits.scala | 2 +- .../connector/expressions/expressions.scala | 66 +++++++------------ .../sql/connector/catalog/InMemoryTable.scala | 6 +- .../expressions/TransformExtractorSuite.scala | 12 ++-- .../datasources/v2/V2SessionCatalog.scala | 14 ++-- .../sql/connector/DataSourceV2SQLSuite.scala | 2 +- 6 files changed, 38 insertions(+), 64 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 605ae2c5075f4..b5e38659724e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -41,7 +41,7 @@ private[sql] object CatalogV2Implicits { val references = spec.bucketColumnNames.map(col => reference(Seq(col))) if (spec.sortColumnNames.nonEmpty) { val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col))) - sortedBucket(spec.numBuckets, references.toArray, sortedCol.toArray) + bucket(spec.numBuckets, references.toArray, sortedCol.toArray) } else { bucket(spec.numBuckets, references.toArray) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 238a51412ffda..712964934816d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -45,7 +45,7 @@ private[sql] object LogicalExpressions { def bucket(numBuckets: Int, references: Array[NamedReference]): BucketTransform = BucketTransform(literal(numBuckets, IntegerType), references) - def sortedBucket( + def bucket( numBuckets: Int, references: Array[NamedReference], sortedCols: Array[NamedReference]): SortedBucketTransform = @@ -121,11 +121,11 @@ private[sql] final case class BucketTransform( } private[sql] object BucketTransform { - def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match { + def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] = expr match { case transform: Transform => transform match { - case BucketTransform(n, FieldReference(parts)) => - Some((n, FieldReference(parts))) + case BucketTransform(n, FieldReference(parts), _) => + Some((n, FieldReference(parts), FieldReference(Seq.empty[String]))) case _ => None } @@ -133,11 +133,23 @@ private[sql] 
object BucketTransform { None } - def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match { - case NamedTransform("bucket", Seq( - Lit(value: Int, IntegerType), - Ref(seq: Seq[String]))) => - Some((value, FieldReference(seq))) + def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] = + transform match { + case NamedTransform("sorted_bucket", arguments) => + var index: Int = -1 + var posOfLit: Int = -1 + var numOfBucket: Int = -1 + arguments.foreach { + case Lit(value: Int, IntegerType) => + numOfBucket = value + index = index + 1 + posOfLit = index + case _ => index = index + 1 + } + Some(numOfBucket, FieldReference(arguments.take(posOfLit).map(_.describe)), + FieldReference(arguments.drop(posOfLit + 1).map(_.describe))) + case NamedTransform("bucket", Seq(Lit(value: Int, IntegerType), Ref(seq: Seq[String]))) => + Some(value, FieldReference(seq), FieldReference(Seq.empty[String])) case _ => None } @@ -148,7 +160,7 @@ private[sql] final case class SortedBucketTransform( columns: Seq[NamedReference], sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform { - override val name: String = "sortedBucket" + override val name: String = "sorted_bucket" override def references: Array[NamedReference] = { arguments.collect { case named: NamedReference => named } @@ -164,40 +176,6 @@ private[sql] final case class SortedBucketTransform( } } -private[sql] object SortedBucketTransform { - def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] = - expr match { - case transform: Transform => - transform match { - case SortedBucketTransform(n, FieldReference(parts), FieldReference(sortCols)) => - Some((n, FieldReference(parts), FieldReference(sortCols))) - case _ => - None - } - case _ => - None - } - - def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] = - transform match { - case NamedTransform("sortedBucket", arguments) => - var index: Int = -1 - var posOfLit: Int = -1 - var numOfBucket: Int = -1 - arguments.foreach { - case Lit(value: Int, IntegerType) => - numOfBucket = value - index = index + 1 - posOfLit = index - case _ => index = index + 1 - } - Some(numOfBucket, FieldReference(arguments.take(posOfLit).map(_.describe)), - FieldReference(arguments.drop(posOfLit + 1).map(_.describe))) - case _ => - None - } -} - private[sql] final case class ApplyTransform( name: String, args: Seq[Expression]) extends Transform { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 6b6fc4cd87801..c0318c20b374e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -162,11 +162,7 @@ class InMemoryTable( case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } - case BucketTransform(numBuckets, ref) => - val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) - val valueHashCode = if (value == null) 0 else value.hashCode - ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets - case SortedBucketTransform(numBuckets, ref, _) => + case BucketTransform(numBuckets, ref, _) => val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) val valueHashCode = if (value == null) 0 else value.hashCode ((valueHashCode + 
31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index 8c895afc773b6..deed48755644a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connector.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst -import org.apache.spark.sql.connector.expressions.LogicalExpressions.{bucket, sortedBucket} +import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket import org.apache.spark.sql.types.DataType class TransformExtractorSuite extends SparkFunSuite { @@ -140,7 +140,7 @@ class TransformExtractorSuite extends SparkFunSuite { } bucketTransform match { - case BucketTransform(numBuckets, FieldReference(seq)) => + case BucketTransform(numBuckets, FieldReference(seq), _) => assert(numBuckets === 16) assert(seq === Seq("a", "b")) case _ => @@ -148,7 +148,7 @@ class TransformExtractorSuite extends SparkFunSuite { } transform("unknown", ref("a", "b")) match { - case BucketTransform(_, _) => + case BucketTransform(_, _, _) => fail("Matched unknown transform") case _ => // expected @@ -160,7 +160,7 @@ class TransformExtractorSuite extends SparkFunSuite { val sortedCol = Array(ref("c"), ref("d")) val sortedBucketTransform = new Transform { - override def name: String = "sortedBucket" + override def name: String = "sorted_bucket" override def references: Array[NamedReference] = col ++ sortedCol override def arguments: Array[Expression] = (col :+ lit(16)) ++ sortedCol override def describe: String = s"bucket(16, ${col(0).describe}, ${col(1).describe} " + @@ -168,7 +168,7 @@ class TransformExtractorSuite extends SparkFunSuite { } sortedBucketTransform match { - case SortedBucketTransform(numBuckets, FieldReference(seq), FieldReference(sorted)) => + case BucketTransform(numBuckets, FieldReference(seq), FieldReference(sorted)) => assert(numBuckets === 16) assert(seq === Seq("a", "b")) assert(sorted === Seq("c", "d")) @@ -194,7 +194,7 @@ class TransformExtractorSuite extends SparkFunSuite { val copied1 = bucketTransform.withReferences(reference1) assert(copied1.equals(bucketTransform)) - val sortedBucketTransform = sortedBucket(16, col, sortedCol) + val sortedBucketTransform = bucket(16, col, sortedCol) val reference2 = sortedBucketTransform.references assert(reference2.length == 4) assert(reference2(0).fieldNames() === Seq("a")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index ab334549f38ac..103b0b5907614 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlread import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Identifier, NamespaceChange, SupportsNamespaces, Table, 
TableCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.catalog.NamespaceChange.RemoveProperty -import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, SortedBucketTransform, Transform} +import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.types.StructType @@ -318,12 +318,12 @@ private[sql] object V2SessionCatalog { case IdentityTransform(FieldReference(Seq(col))) => identityCols += col - case BucketTransform(numBuckets, FieldReference(Seq(col))) => - bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil)) - - case SortedBucketTransform( - numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) => - bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil)) + case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) => + if (sortCol.isEmpty) { + bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil)) + } else { + bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil)) + } case transform => throw QueryExecutionErrors.unsupportedPartitionTransformError(transform) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 86dfe7c324870..70cc90faa1ac7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1580,7 +1580,7 @@ class DataSourceV2SQLSuite val part3 = describe .filter("col_name = 'Part 2'") .select("data_type").head.getString(0) - assert(part3 === "sortedBucket(c, d, 4, e, f)") + assert(part3 === "sorted_bucket(c, d, 4, e, f)") } } From 77b2c1226523c9bb579c2bc4832035a060eaa1a8 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Sun, 9 Jan 2022 16:46:59 -0800 Subject: [PATCH 8/9] address comments --- .../connector/expressions/expressions.scala | 39 ++++++++----------- .../sql/connector/catalog/InMemoryTable.scala | 13 +++++-- .../expressions/TransformExtractorSuite.scala | 10 ++--- .../datasources/v2/V2SessionCatalog.scala | 7 ++-- .../sql/connector/DataSourceV2SQLSuite.scala | 15 +++++-- 5 files changed, 46 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 712964934816d..9402bceafbc5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types.{DataType, IntegerType, StringType} @@ -121,35 +122,29 @@ private[sql] final case class BucketTransform( } private[sql] object BucketTransform { - def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] = expr match { - case transform: Transform => - transform match { - case BucketTransform(n, FieldReference(parts), _) => - Some((n, FieldReference(parts), 
FieldReference(Seq.empty[String]))) - case _ => - None - } - case _ => - None - } - - def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] = + def unapply(transform: Transform): Option[(Int, Seq[NamedReference], Seq[NamedReference])] = transform match { case NamedTransform("sorted_bucket", arguments) => - var index: Int = -1 var posOfLit: Int = -1 var numOfBucket: Int = -1 - arguments.foreach { + arguments.zipWithIndex.foreach { + case (Lit(value: Int, IntegerType), i) => + numOfBucket = value + posOfLit = i + case _ => + } + Some(numOfBucket, arguments.take(posOfLit).map(_.asInstanceOf[NamedReference]), + arguments.drop(posOfLit + 1).map(_.asInstanceOf[NamedReference])) + case NamedTransform("bucket", arguments) => + var numOfBucket: Int = -1 + arguments(0) match { case Lit(value: Int, IntegerType) => numOfBucket = value - index = index + 1 - posOfLit = index - case _ => index = index + 1 + case _ => throw new SparkException("The first element in BucketTransform arguments " + + "should be an Integer Literal.") } - Some(numOfBucket, FieldReference(arguments.take(posOfLit).map(_.describe)), - FieldReference(arguments.drop(posOfLit + 1).map(_.describe))) - case NamedTransform("bucket", Seq(Lit(value: Int, IntegerType), Ref(seq: Seq[String]))) => - Some(value, FieldReference(seq), FieldReference(Seq.empty[String])) + Some(numOfBucket, arguments.drop(1).map(_.asInstanceOf[NamedReference]), + Seq.empty[FieldReference]) case _ => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index c0318c20b374e..8e5e920d89abe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -162,10 +162,15 @@ class InMemoryTable( case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } - case BucketTransform(numBuckets, ref, _) => - val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) - val valueHashCode = if (value == null) 0 else value.hashCode - ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets + case BucketTransform(numBuckets, cols, _) => + val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row)) + var valueHashCode = 0 + valueTypePairs.foreach( pair => + if ( pair._1 != null) valueHashCode += pair._1.hashCode() + ) + var dataTypeHashCode = 0 + valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode()) + ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index deed48755644a..54ab1df3fa8f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -140,9 +140,9 @@ class TransformExtractorSuite extends SparkFunSuite { } bucketTransform match { - case BucketTransform(numBuckets, FieldReference(seq), _) => + case BucketTransform(numBuckets, cols, _) => assert(numBuckets === 16) - assert(seq === Seq("a", "b")) + assert(cols(0).fieldNames === Seq("a", "b")) case _ => fail("Did not match 
BucketTransform extractor") } @@ -168,10 +168,10 @@ class TransformExtractorSuite extends SparkFunSuite { } sortedBucketTransform match { - case BucketTransform(numBuckets, FieldReference(seq), FieldReference(sorted)) => + case BucketTransform(numBuckets, cols, sortCols) => assert(numBuckets === 16) - assert(seq === Seq("a", "b")) - assert(sorted === Seq("c", "d")) + assert(cols.flatMap(c => c.fieldNames()) === Seq("a", "b")) + assert(sortCols.flatMap(c => c.fieldNames()) === Seq("c", "d")) case _ => fail("Did not match BucketTransform extractor") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 103b0b5907614..f3833f53dcf63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -318,11 +318,12 @@ private[sql] object V2SessionCatalog { case IdentityTransform(FieldReference(Seq(col))) => identityCols += col - case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) => + case BucketTransform(numBuckets, col, sortCol) => if (sortCol.isEmpty) { - bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil)) + bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), Nil)) } else { - bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil)) + bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), + sortCol.map(_.fieldNames.mkString(".")))) } case transform => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 70cc90faa1ac7..ec9360dc55c90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -410,9 +410,12 @@ class DataSourceV2SQLSuite test("SPARK-36850: CreateTableAsSelect partitions can be specified using " + "PARTITIONED BY and/or CLUSTERED BY") { val identifier = "testcat.table_name" + val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"), + (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4") + df.createOrReplaceTempView("source_table") withTable(identifier) { spark.sql(s"CREATE TABLE $identifier USING foo PARTITIONED BY (id) " + - s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source") + s"CLUSTERED BY (data1, data2, data3, data4) INTO 4 BUCKETS AS SELECT * FROM source_table") val describe = spark.sql(s"DESCRIBE $identifier") val part1 = describe .filter("col_name = 'Part 0'") @@ -421,18 +424,22 @@ class DataSourceV2SQLSuite val part2 = describe .filter("col_name = 'Part 1'") .select("data_type").head.getString(0) - assert(part2 === "bucket(4, data)") + assert(part2 === "bucket(4, data1, data2, data3, data4)") } } test("SPARK-36850: ReplaceTableAsSelect partitions can be specified using " + "PARTITIONED BY and/or CLUSTERED BY") { val identifier = "testcat.table_name" + val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"), + (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4") + df.createOrReplaceTempView("source_table") withTable(identifier) { spark.sql(s"CREATE TABLE $identifier USING foo " + "AS SELECT id 
FROM source") spark.sql(s"REPLACE TABLE $identifier USING foo PARTITIONED BY (id) " + - s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source") + s"CLUSTERED BY (data1, data2) SORTED by (data3, data4) INTO 4 BUCKETS " + + s"AS SELECT * FROM source_table") val describe = spark.sql(s"DESCRIBE $identifier") val part1 = describe .filter("col_name = 'Part 0'") @@ -441,7 +448,7 @@ class DataSourceV2SQLSuite val part2 = describe .filter("col_name = 'Part 1'") .select("data_type").head.getString(0) - assert(part2 === "bucket(4, data)") + assert(part2 === "sorted_bucket(data1, data2, 4, data3, data4)") } } From dac56935bc71cb60c67ae8cbb9f299ed3f164a6d Mon Sep 17 00:00:00 2001 From: huaxingao Date: Wed, 12 Jan 2022 10:54:00 -0800 Subject: [PATCH 9/9] Trigger Build