From d46483ff8e493fbdebd73a0713afb7cfd6e44e8e Mon Sep 17 00:00:00 2001
From: mingjial
Date: Thu, 13 Aug 2020 23:04:36 -0700
Subject: [PATCH 1/7] [SPARK-32609] Incorrect exchange reuse with DataSourceV2

---
 .../datasources/v2/DataSourceV2ScanExec.scala |  3 ++-
 .../sql/sources/v2/DataSourceV2Suite.scala    | 19 +++++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index c8494f97f1761..9b70eecd2e831 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -47,7 +47,8 @@ case class DataSourceV2ScanExec(
   // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
   override def equals(other: Any): Boolean = other match {
     case other: DataSourceV2ScanExec =>
-      output == other.output && reader.getClass == other.reader.getClass && options == other.options
+      (output == other.output && reader.getClass == other.reader.getClass
+        && options == other.options && pushedFilters == other.pushedFilters)
     case _ => false
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index 2367bdd169522..028b8cad33634 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -371,6 +371,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-32609: DataSourceV2 with different pushedfilters should be different") {
+    def getScanExec(query: DataFrame): DataSourceV2ScanExec = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d
+      }.head
+    }
+
+    Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls =>
+      withClue(cls.getName) {
+        val df = spark.read.format(cls.getName).load()
+        val q1 = df.select('i).filter('i > 6)
+        val q2 = df.select('i).filter('i > 5)
+        val scan1 = getScanExec(q1)
+        val scan2 = getScanExec(q2)
+        assert(!scan1.equals(scan2))
+      }
+    }
+  }
 }
 
 class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {

From 5b1b9b39eb612cbf9ec67efd4e364adafcff66c4 Mon Sep 17 00:00:00 2001
From: mingjial
Date: Fri, 14 Aug 2020 15:21:03 -0700
Subject: [PATCH 2/7] [SPARK-32609] Incorrect exchange reuse with DataSourceV2

---
 .../org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index 028b8cad33634..ef0a8bdfd2fcd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -384,9 +384,12 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
         val df = spark.read.format(cls.getName).load()
         val q1 = df.select('i).filter('i > 6)
         val q2 = df.select('i).filter('i > 5)
+        val q3 = df.select('i).filter('i > 5)
         val scan1 = getScanExec(q1)
         val scan2 = getScanExec(q2)
+        val scan3 = getScanExec(q3)
         assert(!scan1.equals(scan2))
+        assert(scan2.equals(scan3))
       }
     }
   }

From dd0fb242277184abda5c6a4cb03bdec4e930e736 Mon Sep 17 00:00:00 2001
From: mingjial
Date: Thu, 27 Aug 2020 11:12:56 -0700
Subject: [PATCH 3/7] [Spark 32708] Query optimization fails to reuse exchange with DataSourceV2

---
 .../execution/datasources/v2/DataSourceV2Relation.scala     | 5 +++--
 .../execution/datasources/v2/DataSourceV2ScanExec.scala     | 3 ++-
 .../execution/datasources/v2/DataSourceV2Strategy.scala     | 8 ++++----
 .../datasources/v2/DataSourceV2StringFormat.scala           | 3 ++-
 4 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index abc5fb979250a..079327417abbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics}
 import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
@@ -54,7 +55,7 @@ case class DataSourceV2Relation(
     tableIdent.map(_.unquotedString).getOrElse(s"${source.name}:unknown")
   }
 
-  override def pushedFilters: Seq[Expression] = Seq.empty
+  override def pushedFilters: Seq[Filter] = Seq.empty
 
   override def simpleString: String = "RelationV2 " + metadataString
 
@@ -92,7 +93,7 @@ case class StreamingDataSourceV2Relation(
 
   override def simpleString: String = "Streaming RelationV2 " + metadataString
 
-  override def pushedFilters: Seq[Expression] = Nil
+  override def pushedFilters: Seq[Filter] = Nil
 
   override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index c8494f97f1761..5bdc69a1ab116 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
@@ -38,7 +39,7 @@ case class DataSourceV2ScanExec(
     output: Seq[AttributeReference],
     @transient source: DataSourceV2,
     @transient options: Map[String, String],
-    @transient pushedFilters: Seq[Expression],
+    @transient pushedFilters: Seq[Filter],
     @transient reader: DataSourceReader)
   extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan {
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 9d97d3b58f30c..365f3df121b80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Rep
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 
@@ -38,11 +39,11 @@ object DataSourceV2Strategy extends Strategy {
    */
   private def pushFilters(
       reader: DataSourceReader,
-      filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+      filters: Seq[Expression]): (Seq[Filter], Seq[Expression]) = {
     reader match {
       case r: SupportsPushDownFilters =>
         // A map from translated data source filters to original catalyst filter expressions.
-        val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression]
+        val translatedFilterToExpr = mutable.HashMap.empty[Filter, Expression]
         // Catalyst filter expression that can't be translated to data source filters.
         val untranslatableExprs = mutable.ArrayBuffer.empty[Expression]
 
@@ -61,8 +62,7 @@ object DataSourceV2Strategy extends Strategy {
         val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray)
           .map(translatedFilterToExpr)
         // The filters which are marked as pushed to this data source
-        val pushedFilters = r.pushedFilters().map(translatedFilterToExpr)
-        (pushedFilters, untranslatableExprs ++ postScanFilters)
+        (r.pushedFilters(), untranslatableExprs ++ postScanFilters)
 
       case _ => (Nil, filters)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
index 97e6c6d702acb..d0b123c197fdc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
@@ -21,6 +21,7 @@ import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.util.Utils
 
@@ -49,7 +50,7 @@ trait DataSourceV2StringFormat {
   /**
    * The filters which have been pushed to the data source.
    */
-  def pushedFilters: Seq[Expression]
+  def pushedFilters: Seq[Filter]
 
   private def sourceName: String = source match {
     case registered: DataSourceRegister => registered.shortName()

From 69ea44e347ba157c1cb973127908a056eb64d12f Mon Sep 17 00:00:00 2001
From: mingjial
Date: Wed, 9 Sep 2020 20:36:18 -0700
Subject: [PATCH 4/7] [SPARK-32708] Query optimization fails to reuse exchange with DataSourceV2

---
 .../datasources/v2/DataSourceV2ScanExec.scala | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 9b70eecd2e831..2fe563a42b89a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
@@ -52,6 +53,17 @@ case class DataSourceV2ScanExec(
     case _ => false
   }
 
+  override def doCanonicalize(): DataSourceV2ScanExec = {
+    DataSourceV2ScanExec(
+      output.map(QueryPlan.normalizeExprId(_, output)),
+      source,
+      options,
+      QueryPlan.normalizePredicates(
+        pushedFilters,
+        AttributeSeq(pushedFilters.flatMap(_.references).distinct)),
+      reader)
+  }
+
   override def hashCode(): Int = {
     Seq(output, source, options).hashCode()
   }

From a6e4709fed67416a73f9b8635d8fc7be4e412754 Mon Sep 17 00:00:00 2001
From: mingjial
Date: Wed, 9 Sep 2020 20:39:34 -0700
Subject: [PATCH 5/7] Revert "[Spark 32708] Query optimization fails to reuse exchange with DataSourceV2"

This reverts commit dd0fb242277184abda5c6a4cb03bdec4e930e736.
---
 .../execution/datasources/v2/DataSourceV2Relation.scala     | 5 ++---
 .../execution/datasources/v2/DataSourceV2ScanExec.scala     | 3 +--
 .../execution/datasources/v2/DataSourceV2Strategy.scala     | 8 ++++----
 .../datasources/v2/DataSourceV2StringFormat.scala           | 3 +--
 4 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 079327417abbb..abc5fb979250a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics}
 import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
@@ -55,7 +54,7 @@ case class DataSourceV2Relation(
     tableIdent.map(_.unquotedString).getOrElse(s"${source.name}:unknown")
   }
 
-  override def pushedFilters: Seq[Filter] = Seq.empty
+  override def pushedFilters: Seq[Expression] = Seq.empty
 
   override def simpleString: String = "RelationV2 " + metadataString
 
@@ -93,7 +92,7 @@ case class StreamingDataSourceV2Relation(
 
   override def simpleString: String = "Streaming RelationV2 " + metadataString
 
-  override def pushedFilters: Seq[Filter] = Nil
+  override def pushedFilters: Seq[Expression] = Nil
 
   override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 7fd0d63cbc475..2fe563a42b89a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
@@ -40,7 +39,7 @@ case class DataSourceV2ScanExec(
     output: Seq[AttributeReference],
     @transient source: DataSourceV2,
     @transient options: Map[String, String],
-    @transient pushedFilters: Seq[Filter],
+    @transient pushedFilters: Seq[Expression],
     @transient reader: DataSourceReader)
   extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan {
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 365f3df121b80..9d97d3b58f30c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Rep
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
-import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 
@@ -39,11 +38,11 @@ object DataSourceV2Strategy extends Strategy {
    */
   private def pushFilters(
       reader: DataSourceReader,
-      filters: Seq[Expression]): (Seq[Filter], Seq[Expression]) = {
+      filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
     reader match {
       case r: SupportsPushDownFilters =>
         // A map from translated data source filters to original catalyst filter expressions.
-        val translatedFilterToExpr = mutable.HashMap.empty[Filter, Expression]
+        val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression]
         // Catalyst filter expression that can't be translated to data source filters.
         val untranslatableExprs = mutable.ArrayBuffer.empty[Expression]
 
@@ -62,7 +61,8 @@ object DataSourceV2Strategy extends Strategy {
         val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray)
           .map(translatedFilterToExpr)
         // The filters which are marked as pushed to this data source
-        (r.pushedFilters(), untranslatableExprs ++ postScanFilters)
+        val pushedFilters = r.pushedFilters().map(translatedFilterToExpr)
+        (pushedFilters, untranslatableExprs ++ postScanFilters)
 
       case _ => (Nil, filters)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
index d0b123c197fdc..97e6c6d702acb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
@@ -21,7 +21,6 @@ import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.util.Utils
 
@@ -50,7 +49,7 @@ trait DataSourceV2StringFormat {
   /**
    * The filters which have been pushed to the data source.
    */
-  def pushedFilters: Seq[Filter]
+  def pushedFilters: Seq[Expression]
 
   private def sourceName: String = source match {
     case registered: DataSourceRegister => registered.shortName()

From 98483c82151634760a82fc860ac6de26d4b024ba Mon Sep 17 00:00:00 2001
From: mingjial
Date: Thu, 10 Sep 2020 17:57:13 -0700
Subject: [PATCH 6/7] [SPARK-32708] Query optimization fails to reuse exchange with DataSourceV2

---
 .../sql/sources/v2/DataSourceV2Suite.scala | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index ef0a8bdfd2fcd..a09f61cc26a2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -393,6 +393,29 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-32708: same columns with different ExprIds should be equal after canonicalization ") {
+    def getV2ScanExecs(query: DataFrame): Seq[DataSourceV2ScanExec] = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d
+      }
+    }
+
+    val df1 = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+    val q1 = df1.select(($"i" + 1).as("k"), ($"i" - 1).as("j")).filter('i > 5)
+    val df2 = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+    val q2 = df2.select(($"i" + 1).as("k"), ($"i" - 1).as("j")).filter('i > 5)
+
+    val scans1 = getV2ScanExecs(q1.join(q2, "j"))
+    assert(scans1(0).sameResult(scans1(1)))
+    assert(scans1(0).doCanonicalize().equals(scans1(1).doCanonicalize()))
+
+    val q3 = df2.select(($"i" + 1).as("k"), ($"i" - 1).as("j")).filter('i > 6)
+    val scans2 = getV2ScanExecs(q1.join(q3, "j"))
+    assert(!scans2(0).sameResult(scans2(1)))
+    assert(!scans2(0).doCanonicalize().equals(scans2(1).doCanonicalize()))
+  }
+
 }
 
 class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {

From 8b864e7921061958c43006335196898e6f3c4be8 Mon Sep 17 00:00:00 2001
From: mingjial
Date: Thu, 10 Sep 2020 18:20:16 -0700
Subject: [PATCH 7/7] [SPARK-32708] Query optimization fails to reuse exchange with DataSourceV2

---
 .../sql/sources/v2/DataSourceV2Suite.scala | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index a09f61cc26a2e..92e0ac93db568 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -395,25 +395,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-32708: same columns with different ExprIds should be equal after canonicalization ") {
-    def getV2ScanExecs(query: DataFrame): Seq[DataSourceV2ScanExec] = {
+    def getV2ScanExec(query: DataFrame): DataSourceV2ScanExec = {
       query.queryExecution.executedPlan.collect {
         case d: DataSourceV2ScanExec => d
-      }
+      }.head
     }
 
     val df1 = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
-    val q1 = df1.select(($"i" + 1).as("k"), ($"i" - 1).as("j")).filter('i > 5)
+    val q1 = df1.select('i).filter('i > 6)
     val df2 = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
-    val q2 = df2.select(($"i" + 1).as("k"), ($"i" - 1).as("j")).filter('i > 5)
-
-    val scans1 = getV2ScanExecs(q1.join(q2, "j"))
-    assert(scans1(0).sameResult(scans1(1)))
-    assert(scans1(0).doCanonicalize().equals(scans1(1).doCanonicalize()))
-
-    val q3 = df2.select(($"i" + 1).as("k"), ($"i" - 1).as("j")).filter('i > 6)
-    val scans2 = getV2ScanExecs(q1.join(q3, "j"))
-    assert(!scans2(0).sameResult(scans2(1)))
-    assert(!scans2(0).doCanonicalize().equals(scans2(1).doCanonicalize()))
+    val q2 = df2.select('i).filter('i > 6)
+    val scan1 = getV2ScanExec(q1)
+    val scan2 = getV2ScanExec(q2)
+    assert(scan1.sameResult(scan2))
+    assert(scan1.doCanonicalize().equals(scan2.doCanonicalize()))
+
+    val q3 = df2.select('i).filter('i > 5)
+    val scan3 = getV2ScanExec(q3)
+    assert(!scan1.sameResult(scan3))
+    assert(!scan1.doCanonicalize().equals(scan3.doCanonicalize()))
   }
 
 }