diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportDistinctKeys.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportDistinctKeys.java
new file mode 100644
index 0000000000000..7e7f806d24ad4
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportDistinctKeys.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import java.util.Set;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * A mix-in interface for {@link Scan}. Data sources can implement this interface to
+ * report sets of distinct keys to Spark.
+ *
+ * Spark will optimize the query plan according to the given distinct keys.
+ * For example, Spark will eliminate the `Distinct` below if the v2 relation only outputs
+ * unique attributes:
+ *
+ * Distinct
+ * +- RelationV2[unique_key#1]
+ *
+ *
+ * Note that Spark does not validate whether the values are actually unique; the
+ * implementation must guarantee this.
+ *
+ * @since 3.4.0
+ */
+@Evolving
+public interface SupportsReportDistinctKeys extends Scan {
+ /**
+ * Returns the sets of distinct keys. Each distinct key can consist of multiple attributes.
+ */
+ Set<Set<NamedReference>> distinctKeysSet();
+}
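
For context, a minimal Scala sketch (not part of this patch; the class and column names are hypothetical) of a connector scan adopting the new mix-in:

    import com.google.common.collect.Sets

    import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
    import org.apache.spark.sql.connector.read.{Scan, SupportsReportDistinctKeys}
    import org.apache.spark.sql.types.StructType

    // Hypothetical scan that declares the "id" column as a single-attribute distinct key.
    class MyUniqueKeyScan(schema: StructType) extends Scan with SupportsReportDistinctKeys {
      override def readSchema(): StructType = schema

      override def distinctKeysSet(): java.util.Set[java.util.Set[NamedReference]] = {
        val idKey: java.util.Set[NamedReference] =
          Sets.newHashSet[NamedReference](FieldReference("id"))
        Sets.newHashSet(idKey)
      }
    }

With such a scan, the propagation described in the Javadoc lets Spark drop a `Distinct` (or an equivalent redundant aggregate) over a relation that only outputs `id`.
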
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index c252ea5ccfe03..d93c4ec0a39a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -49,6 +49,10 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
refs.map(ref => resolveRef[T](ref, plan))
}
+ def resolveRefs[T <: NamedExpression](refs: Set[NamedReference], plan: LogicalPlan): Set[T] = {
+ refs.map(ref => resolveRef[T](ref, plan))
+ }
+
/**
* Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala
index 1f495688bc5e3..4c2cbfdba6df7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala
@@ -62,7 +62,10 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
}
}
- override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet]
+ override def default(p: LogicalPlan): Set[ExpressionSet] = p match {
+ case leaf: LeafNode => leaf.reportDistinctKeysSet()
+ case _ => Set.empty[ExpressionSet]
+ }
override def visitAggregate(p: Aggregate): Set[ExpressionSet] = {
// handle group by a, a and global aggregate
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 7640d9234c71f..6ba011c4ddc60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -167,6 +167,9 @@ abstract class LogicalPlan
trait LeafNode extends LogicalPlan with LeafLike[LogicalPlan] {
override def producedAttributes: AttributeSet = outputSet
+ /** Return a set of unique keys. */
+ def reportDistinctKeysSet(): Set[ExpressionSet] = Set.empty[ExpressionSet]
+
/** Leaf nodes that can survive analysis must define their own statistics. */
def computeStats(): Statistics = throw new UnsupportedOperationException
}
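
For context (not part of this patch), a minimal sketch of how the new `LeafNode` hook is picked up by `DistinctKeyVisitor.default`; `FakeLeafNode` is a hypothetical node used only for illustration:

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpressionSet}
    import org.apache.spark.sql.catalyst.plans.logical.LeafNode
    import org.apache.spark.sql.types.IntegerType

    // Hypothetical leaf node whose "key" attribute is unique.
    case class FakeLeafNode() extends LeafNode {
      private val key = AttributeReference("key", IntegerType)()
      private val value = AttributeReference("value", IntegerType)()

      override def output: Seq[AttributeReference] = Seq(key, value)

      // Returned by DistinctKeyVisitor.default for leaf nodes, so the distinct-key
      // propagation framework can treat `key` as unique for this relation.
      override def reportDistinctKeysSet(): Set[ExpressionSet] = Set(ExpressionSet(key :: Nil))
    }
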
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 2045c59933739..cad0a581ec5fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -17,12 +17,14 @@
package org.apache.spark.sql.execution.datasources.v2
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ExpressionSet, SortOrder, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
-import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics}
+import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportDistinctKeys, SupportsReportStatistics}
import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils
@@ -75,6 +77,13 @@ case class DataSourceV2Relation(
s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $qualifiedTableName $name"
}
+ override def reportDistinctKeysSet(): Set[ExpressionSet] = {
+ table.asReadable.newScanBuilder(options).build() match {
+ case r: SupportsReportDistinctKeys => DataSourceV2Relation.transformUniqueKeysSet(r, this)
+ case _ => super.reportDistinctKeysSet()
+ }
+ }
+
override def computeStats(): Statistics = {
if (Utils.isTesting) {
// when testing, throw an exception if this computeStats method is called because stats should
@@ -134,6 +143,11 @@ case class DataSourceV2ScanRelation(
s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name"
}
+ override def reportDistinctKeysSet(): Set[ExpressionSet] = scan match {
+ case r: SupportsReportDistinctKeys => DataSourceV2Relation.transformUniqueKeysSet(r, this)
+ case _ => super.reportDistinctKeysSet()
+ }
+
override def computeStats(): Statistics = {
scan match {
case r: SupportsReportStatistics =>
@@ -166,6 +180,11 @@ case class StreamingDataSourceV2Relation(
override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))
+ override def reportDistinctKeysSet(): Set[ExpressionSet] = scan match {
+ case r: SupportsReportDistinctKeys => DataSourceV2Relation.transformUniqueKeysSet(r, this)
+ case _ => super.reportDistinctKeysSet()
+ }
+
override def computeStats(): Statistics = scan match {
case r: SupportsReportStatistics =>
val statistics = r.estimateStatistics()
@@ -220,4 +239,13 @@ object DataSourceV2Relation {
sizeInBytes = v2Statistics.sizeInBytes().orElse(defaultSizeInBytes),
rowCount = numRows)
}
+
+ def transformUniqueKeysSet(
+ r: SupportsReportDistinctKeys,
+ p: LogicalPlan): Set[ExpressionSet] = {
+ val uniqueKeysSet = r.distinctKeysSet().asScala
+ uniqueKeysSet.map { uniqueKeys =>
+ ExpressionSet(V2ExpressionUtils.resolveRefs(uniqueKeys.asScala.toSet, p))
+ }.toSet
+ }
}
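
As an illustration (assuming a relation with output `[i#1, j#2]` whose scan reports `{{i}, {j}}`, mirroring the test sources added below), the helper resolves each `NamedReference` against the plan's output:

    // distinctKeysSet() == { {FieldReference("i")}, {FieldReference("j")} }
    // transformUniqueKeysSet(scan, relation)
    //   == Set(ExpressionSet(Seq(i#1)), ExpressionSet(Seq(j#2)))
    // which is exactly what checkUniqueKeys asserts in DataSourceV2Suite below.
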
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index 3255dee0a16b0..e4e6ba4cd845d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -24,6 +24,7 @@ import java.util.OptionalLong
import scala.collection.mutable
+import com.google.common.collect.Sets
import org.scalatest.Assertions._
import org.apache.spark.sql.catalyst.InternalRow
@@ -272,10 +273,22 @@ class InMemoryTable(
var data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
- extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning {
+ extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning
+ with SupportsReportDistinctKeys {
override def toBatch: Batch = this
+ override def distinctKeysSet(): java.util.Set[java.util.Set[NamedReference]] = {
+ val uniqueKeys = readSchema.fields
+ .collect { case f if f.metadata.contains("unique") => f.name }
+ .map(FieldReference(_))
+ .map(Sets.newHashSet(_))
+
+ Sets.newHashSet(uniqueKeys: _*)
+ }
+
override def estimateStatistics(): Statistics = {
if (data.isEmpty) {
return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L))
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportDistinctKeysDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportDistinctKeysDataSource.java
new file mode 100644
index 0000000000000..efe0f1e6dac7e
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportDistinctKeysDataSource.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.connector;
+
+import com.google.common.collect.Sets;
+import org.apache.spark.sql.connector.TestingV2Source;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.expressions.FieldReference;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.read.*;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+
+import java.util.Set;
+
+public class JavaReportDistinctKeysDataSource implements TestingV2Source {
+ static class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportDistinctKeys {
+ @Override
+ public Set<Set<NamedReference>> distinctKeysSet() {
+ return Sets.newHashSet(
+ Sets.newHashSet(FieldReference.apply("i")),
+ Sets.newHashSet(FieldReference.apply("j")));
+ }
+
+ @Override
+ public InputPartition[] planInputPartitions() {
+ InputPartition[] partitions = new InputPartition[1];
+ partitions[0] = new JavaRangeInputPartition(0, 1);
+ return partitions;
+ }
+ }
+
+ @Override
+ public Table getTable(CaseInsensitiveStringMap options) {
+ return new JavaSimpleBatchTable() {
+ @Override
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
+ return new JavaReportDistinctKeysDataSource.MyScanBuilder();
+ }
+ };
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index 158e1634d58c5..29c936cfd2bc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -21,12 +21,12 @@ import java.util.Collections
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, Metadata, StructField, StructType}
import org.apache.spark.sql.util.QueryExecutionListener
class DataSourceV2DataFrameSuite
@@ -253,4 +253,22 @@ class DataSourceV2DataFrameSuite
spark.listenerManager.unregister(listener)
}
}
+
+ test("SPARK-38932: Datasource v2 support report distinct keys") {
+ val t = "testcat.unique.t"
+ withTable(t) {
+ val unique = """{"unique": ""}"""
+ val schema = StructType(
+ StructField("key", IntegerType, metadata = Metadata.fromJson(unique)) ::
+ StructField("value", IntegerType) :: Nil)
+ val data = spark.sparkContext.parallelize(Row(1, 1) :: Row(2, 1) :: Nil)
+ spark.createDataFrame(data, schema).writeTo(t).create()
+
+ val qe = spark.table(t).groupBy($"key").agg($"key").queryExecution
+ val analyzed = qe.analyzed
+ val optimized = qe.optimizedPlan
+ assert(analyzed.exists(_.isInstanceOf[Aggregate]))
+ assert(!optimized.exists(_.isInstanceOf[Aggregate]))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 5c4be75e02c7f..67ccdbce0646d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -20,10 +20,13 @@ package org.apache.spark.sql.connector
import java.io.File
import java.util.OptionalLong
+import com.google.common.collect.Sets
import test.org.apache.spark.sql.connector._
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.ExpressionSet
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
@@ -596,6 +599,30 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}
}
}
+
+ test("SPARK-38932: Datasource v2 support report distinct keys") {
+ def checkUniqueKeys(leaf: LeafNode): Unit = {
+ // Assume all output attributes are unique keys
+ val expected = leaf.output.map(attr => ExpressionSet(attr :: Nil)).toSet
+ assert(leaf.reportDistinctKeysSet() == expected)
+ }
+
+ Seq(classOf[ReportUniqueKeysDataSource], classOf[JavaReportDistinctKeysDataSource]).foreach {
+ cls =>
+ withClue(cls.getName) {
+ val df = spark.read.format(cls.getName).load()
+ val analyzed = df.queryExecution.analyzed.collect {
+ case d: DataSourceV2Relation => d
+ }.head
+ checkUniqueKeys(analyzed)
+
+ val optimized = df.queryExecution.optimizedPlan.collect {
+ case d: DataSourceV2ScanRelation => d
+ }.head
+ checkUniqueKeys(optimized)
+ }
+ }
+ }
}
@@ -1106,3 +1133,26 @@ class ReportStatisticsDataSource extends SimpleWritableDataSource {
}
}
}
+
+class ReportUniqueKeysDataSource extends SimpleWritableDataSource {
+
+ class MyScanBuilder extends SimpleScanBuilder with SupportsReportDistinctKeys {
+ override def distinctKeysSet(): java.util.Set[java.util.Set[NamedReference]] = {
+ Sets.newHashSet(
+ Sets.newHashSet(FieldReference.apply("i")),
+ Sets.newHashSet(FieldReference.apply("j")))
+ }
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ Array(RangeInputPartition(0, 1))
+ }
+ }
+
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new MyScanBuilder
+ }
+ }
+ }
+}