diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportDistinctKeys.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportDistinctKeys.java
new file mode 100644
index 0000000000000..7e7f806d24ad4
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportDistinctKeys.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import java.util.Set;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * A mix-in interface for {@link Scan}. Data sources can implement this interface to
+ * report sets of unique keys to Spark.
+ * <p>
+ * Spark will optimize the query plan according to the given unique keys.
+ * For example, Spark will eliminate the `Distinct` if the v2 relation only outputs the unique
+ * attributes:
+ * <pre>
+ *   Distinct
+ *     +- RelationV2[unique_key#1]
+ * </pre>
+ * <p>
+ * Note that Spark does not validate whether the values are unique; the implementation
+ * should guarantee this.
+ *
+ * @since 3.4.0
+ */
+@Evolving
+public interface SupportsReportDistinctKeys extends Scan {
+  /**
+   * Returns a set of unique keys. Each unique key can consist of multiple attributes.
+   */
+  Set<Set<NamedReference>> distinctKeysSet();
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index c252ea5ccfe03..d93c4ec0a39a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -49,6 +49,10 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
     refs.map(ref => resolveRef[T](ref, plan))
   }
 
+  def resolveRefs[T <: NamedExpression](refs: Set[NamedReference], plan: LogicalPlan): Set[T] = {
+    refs.map(ref => resolveRef[T](ref, plan))
+  }
+
   /**
    * Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala
index 1f495688bc5e3..4c2cbfdba6df7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala
@@ -62,7 +62,10 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
     }
   }
 
-  override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet]
+  override def default(p: LogicalPlan): Set[ExpressionSet] = p match {
+    case leaf: LeafNode => leaf.reportDistinctKeysSet()
+    case _ => Set.empty[ExpressionSet]
+  }
 
   override def visitAggregate(p: Aggregate): Set[ExpressionSet] = {
     // handle group by a, a and global aggregate
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 7640d9234c71f..6ba011c4ddc60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -167,6 +167,9 @@ abstract class LogicalPlan
 trait LeafNode extends LogicalPlan with LeafLike[LogicalPlan] {
   override def producedAttributes: AttributeSet = outputSet
 
+  /** Returns a set of unique keys. */
+  def reportDistinctKeysSet(): Set[ExpressionSet] = Set.empty[ExpressionSet]
+
   /** Leaf nodes that can survive analysis must define their own statistics. */
   def computeStats(): Statistics = throw new UnsupportedOperationException
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 2045c59933739..cad0a581ec5fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ExpressionSet, SortOrder, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
 import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
-import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics}
+import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportDistinctKeys, SupportsReportStatistics}
 import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.Utils
@@ -75,6 +77,13 @@ case class DataSourceV2Relation(
     s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $qualifiedTableName $name"
   }
 
+  override def reportDistinctKeysSet(): Set[ExpressionSet] = {
+    table.asReadable.newScanBuilder(options).build() match {
+      case r: SupportsReportDistinctKeys => DataSourceV2Relation.transformUniqueKeysSet(r, this)
+      case _ => super.reportDistinctKeysSet()
+    }
+  }
+
   override def computeStats(): Statistics = {
     if (Utils.isTesting) {
       // when testing, throw an exception if this computeStats method is called because stats should
@@ -134,6 +143,11 @@ case class DataSourceV2ScanRelation(
     s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name"
   }
 
+  override def reportDistinctKeysSet(): Set[ExpressionSet] = scan match {
+    case r: SupportsReportDistinctKeys => DataSourceV2Relation.transformUniqueKeysSet(r, this)
+    case _ => super.reportDistinctKeysSet()
+  }
+
   override def computeStats(): Statistics = {
     scan match {
       case r: SupportsReportStatistics =>
@@ -166,6 +180,11 @@ case class StreamingDataSourceV2Relation(
 
   override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))
 
+  override def reportDistinctKeysSet(): Set[ExpressionSet] = scan match {
+    case r: SupportsReportDistinctKeys => DataSourceV2Relation.transformUniqueKeysSet(r, this)
+    case _ => super.reportDistinctKeysSet()
+  }
+
   override def computeStats(): Statistics = scan match {
     case r: SupportsReportStatistics =>
       val statistics = r.estimateStatistics()
@@ -220,4 +239,13 @@ object DataSourceV2Relation {
       sizeInBytes = v2Statistics.sizeInBytes().orElse(defaultSizeInBytes),
       rowCount = numRows)
   }
+
+  def transformUniqueKeysSet(
+      r: SupportsReportDistinctKeys,
+      p: LogicalPlan): Set[ExpressionSet] = {
+    val uniqueKeysSet = r.distinctKeysSet().asScala
+    uniqueKeysSet.map { uniqueKeys =>
+      ExpressionSet(V2ExpressionUtils.resolveRefs(uniqueKeys.asScala.toSet, p))
+    }.toSet
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index 3255dee0a16b0..e4e6ba4cd845d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -24,6 +24,7 @@ import java.util.OptionalLong
 
 import scala.collection.mutable
 
+import com.google.common.collect.Sets
 import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -272,10 +273,22 @@ class InMemoryTable(
       var data: Seq[InputPartition],
       readSchema: StructType,
       tableSchema: StructType)
-    extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning {
+    extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning
+      with SupportsReportDistinctKeys {
 
     override def toBatch: Batch = this
 
+    override def distinctKeysSet(): java.util.Set[java.util.Set[NamedReference]] = {
+      val uniqueKeys = readSchema.fields.collect {
+        case f if f.metadata.contains("unique") => f.name
+      }.map(FieldReference(_))
+        .map(Sets.newHashSet(_))
+
+      Sets.newHashSet(
+        uniqueKeys: _*
+      )
+    }
+
     override def estimateStatistics(): Statistics = {
       if (data.isEmpty) {
         return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L))
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportDistinctKeysDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportDistinctKeysDataSource.java
new file mode 100644
index 0000000000000..efe0f1e6dac7e
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportDistinctKeysDataSource.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.connector;
+
+import com.google.common.collect.Sets;
+import org.apache.spark.sql.connector.TestingV2Source;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.expressions.FieldReference;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.read.*;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+
+import java.util.Set;
+
+public class JavaReportDistinctKeysDataSource implements TestingV2Source {
+  static class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportDistinctKeys {
+    @Override
+    public Set<Set<NamedReference>> distinctKeysSet() {
+      return Sets.newHashSet(
+        Sets.newHashSet(FieldReference.apply("i")),
+        Sets.newHashSet(FieldReference.apply("j")));
+    }
+
+    @Override
+    public InputPartition[] planInputPartitions() {
+      InputPartition[] partitions = new InputPartition[1];
+      partitions[0] = new JavaRangeInputPartition(0, 1);
+      return partitions;
+    }
+  }
+
+  @Override
+  public Table getTable(CaseInsensitiveStringMap options) {
+    return new JavaSimpleBatchTable() {
+      @Override
+      public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
+        return new JavaReportDistinctKeysDataSource.MyScanBuilder();
+      }
+    };
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index 158e1634d58c5..29c936cfd2bc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -21,12 +21,12 @@ import java.util.Collections
 
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, Metadata, StructField, StructType}
 import org.apache.spark.sql.util.QueryExecutionListener
 
 class DataSourceV2DataFrameSuite
@@ -253,4 +253,22 @@ class DataSourceV2DataFrameSuite
       spark.listenerManager.unregister(listener)
     }
   }
+
+  test("SPARK-38932: Datasource v2 support report distinct keys") {
+    val t = "testcat.unique.t"
+    withTable(t) {
+      val unique = """{"unique": ""}"""
+      val schema = StructType(
+        StructField("key", IntegerType, metadata = Metadata.fromJson(unique)) ::
+          StructField("value", IntegerType) :: Nil)
+      val data = spark.sparkContext.parallelize(Row(1, 1) :: Row(2, 1) :: Nil)
+      spark.createDataFrame(data, schema).writeTo(t).create()
+
+      val qe = spark.table(t).groupBy($"key").agg($"key").queryExecution
+      val analyzed = qe.analyzed
+      val optimized = qe.optimizedPlan
+      assert(analyzed.exists(_.isInstanceOf[Aggregate]))
+      assert(!optimized.exists(_.isInstanceOf[Aggregate]))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 5c4be75e02c7f..67ccdbce0646d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -20,10 +20,13 @@ package org.apache.spark.sql.connector
 import java.io.File
 import java.util.OptionalLong
 
+import com.google.common.collect.Sets
 import test.org.apache.spark.sql.connector._
 
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.ExpressionSet
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
 import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
@@ -596,6 +599,30 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("SPARK-38932: Datasource v2 support report distinct keys") {
+    def checkUniqueKeys(leaf: LeafNode): Unit = {
+      // Assume all output attributes are unique keys
+      val expected1 = leaf.output.map(attr => ExpressionSet(attr :: Nil)).toSet
+      assert(leaf.reportDistinctKeysSet() == expected1)
+    }
+
+    Seq(classOf[ReportUniqueKeysDataSource], classOf[JavaReportDistinctKeysDataSource]).foreach {
+      cls =>
+        withClue(cls.getName) {
+          val df = spark.read.format(cls.getName).load()
+          val analyzed = df.queryExecution.analyzed.collect {
+            case d: DataSourceV2Relation => d
+          }.head
+          checkUniqueKeys(analyzed)
+
+          val optimized = df.queryExecution.optimizedPlan.collect {
+            case d: DataSourceV2ScanRelation => d
+          }.head
+          checkUniqueKeys(optimized)
+        }
+    }
+  }
 }
@@ -1106,3 +1133,26 @@ class ReportStatisticsDataSource extends SimpleWritableDataSource {
     }
   }
 }
+
+class ReportUniqueKeysDataSource extends SimpleWritableDataSource {
+
+  class MyScanBuilder extends SimpleScanBuilder with SupportsReportDistinctKeys {
+    override def distinctKeysSet(): java.util.Set[java.util.Set[NamedReference]] = {
+      Sets.newHashSet(
+        Sets.newHashSet(FieldReference.apply("i")),
+        Sets.newHashSet(FieldReference.apply("j")))
+    }
+
+    override def planInputPartitions(): Array[InputPartition] = {
+      Array(RangeInputPartition(0, 1))
+    }
+  }
+
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    new SimpleBatchTable {
+      override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+        new MyScanBuilder
+      }
+    }
+  }
+}
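
For connector authors, the shape of an implementation outside Spark's own test sources looks roughly like the sketch below. It is illustrative only, not part of this patch: `CustomerScan`, its schema, and the assumption that `customer_id` is unique are invented for the example; only `SupportsReportDistinctKeys`, `FieldReference`, and `NamedReference` come from the interface added above.

import java.util.{Set => JSet}

import scala.collection.JavaConverters._

import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.read.{Scan, SupportsReportDistinctKeys}
import org.apache.spark.sql.types.StructType

// Hypothetical scan over a customer table where `customer_id` is known to be
// unique. The batch reading plumbing (toBatch, planInputPartitions, ...) is omitted.
class CustomerScan extends Scan with SupportsReportDistinctKeys {

  override def readSchema(): StructType =
    new StructType().add("customer_id", "long").add("name", "string")

  // Report {customer_id} as a unique key. Spark does not verify the claim; if the
  // data actually contains duplicates, query results may be wrong.
  override def distinctKeysSet(): JSet[JSet[NamedReference]] = {
    val key: JSet[NamedReference] = Set[NamedReference](FieldReference("customer_id")).asJava
    Set(key).asJava
  }
}

With such a scan, a query like SELECT DISTINCT customer_id FROM ... no longer needs a Distinct node, because DistinctKeyVisitor.default now picks up the keys reported by the leaf v2 relation.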