From ce02f753b6bf6d6a811526c30595f7854737b020 Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Fri, 15 Apr 2022 16:09:35 -0700
Subject: [PATCH 1/3] fix schema pruning

---
 .../expressions/ProjectionOverSchema.scala    |  8 +++-
 .../sql/catalyst/optimizer/objects.scala      |  2 +-
 .../execution/datasources/SchemaPruning.scala |  4 +-
 .../v2/V2ScanRelationPushDown.scala           |  6 +--
 .../datasources/SchemaPruningSuite.scala      | 40 ++++++++++++++++++-
 5 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index a6be98c8a3aae..0b1765d5eea13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -24,15 +24,19 @@ import org.apache.spark.sql.types._
  * field indexes and field counts of complex type extractors and attributes
  * are adjusted to fit the schema. All other expressions are left as-is. This
  * class is motivated by columnar nested schema pruning.
+ *
+ * @param schema nested column schema
+ * @param output output attributes of the data source relation. They are used to filter out
+ *               attributes in the schema that do not belong to the current relation.
  */
-case class ProjectionOverSchema(schema: StructType) {
+case class ProjectionOverSchema(schema: StructType, output: Option[AttributeSet] = None) {
   private val fieldNames = schema.fieldNames.toSet
 
   def unapply(expr: Expression): Option[Expression] = getProjection(expr)
 
   private def getProjection(expr: Expression): Option[Expression] = expr match {
-    case a: AttributeReference if fieldNames.contains(a.name) =>
+    case a: AttributeReference if fieldNames.contains(a.name) && output.forall(_.contains(a)) =>
       Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
     case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
       getProjection(child).map {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 82aef32c5a22f..dee7ff5bac6ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -229,7 +229,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
     }
 
     // Builds new projection.
-    val projectionOverSchema = ProjectionOverSchema(prunedSchema)
+    val projectionOverSchema = ProjectionOverSchema(prunedSchema, Some(AttributeSet(s.output)))
     val newProjects = p.projectList.map(_.transformDown {
       case projectionOverSchema(expr) => expr
     }).map { case expr: NamedExpression => expr }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index a49c10c852b08..699f8a15f489c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -91,8 +91,8 @@ object SchemaPruning extends Rule[LogicalPlan] {
       if (countLeaves(hadoopFsRelation.dataSchema) > countLeaves(prunedDataSchema) ||
         countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) {
         val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
-        val projectionOverSchema =
-          ProjectionOverSchema(prunedDataSchema.merge(prunedMetadataSchema))
+        val projectionOverSchema = ProjectionOverSchema(
+          prunedDataSchema.merge(prunedMetadataSchema), Some(AttributeSet(relation.output)))
         Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
           prunedRelation, projectionOverSchema))
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 6455e25089276..c9154aabf8766 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -320,7 +319,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
     val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
 
-    val projectionOverSchema = ProjectionOverSchema(output.toStructType)
+    val projectionOverSchema =
+      ProjectionOverSchema(output.toStructType, Some(AttributeSet(output)))
     val projectionFunc = (expr: Expression) => expr transformDown {
       case projectionOverSchema(newExpr) => newExpr
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 4eb8258830ce5..8cdaaae2feebb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -61,11 +61,15 @@ abstract class SchemaPruningSuite
   override protected def sparkConf: SparkConf =
     super.sparkConf.set(SQLConf.ANSI_STRICT_INDEX_OPERATOR.key, "false")
 
+  case class Employee(id: Int, name: FullName, employer: Company)
+
   val janeDoe = FullName("Jane", "X.", "Doe")
   val johnDoe = FullName("John", "Y.", "Doe")
   val susanSmith = FullName("Susan", "Z.", "Smith")
 
-  val employer = Employer(0, Company("abc", "123 Business Street"))
+  val company = Company("abc", "123 Business Street")
+
+  val employer = Employer(0, company)
   val employerWithNullCompany = Employer(1, null)
   val employerWithNullCompany2 = Employer(2, null)
@@ -81,6 +85,8 @@ abstract class SchemaPruningSuite
     Department(1, "Marketing", 1, employerWithNullCompany) ::
     Department(2, "Operation", 4, employerWithNullCompany2) :: Nil
 
+  val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil
+
   case class Name(first: String, last: String)
   case class BriefContact(id: Int, name: Name, address: String)
@@ -621,6 +627,21 @@ abstract class SchemaPruningSuite
     }
   }
 
+  testSchemaPruning("SPARK-38918: nested schema pruning with correlated subqueries") {
+    withContacts {
+      withEmployees {
+        val query = sql(
+          """
+            |select count(*)
+            |from contacts c
+            |where not exists (select null from employees e where e.name.first = c.name.first
+            |  and e.employer.name = c.employer.company.name)
+            |""".stripMargin)
+        checkAnswer(query, Row(3))
+      }
+    }
+  }
+
   protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
     test(s"Spark vectorized reader - without partition data column - $testName") {
       withSQLConf(vectorizedReaderEnabledKey -> "true") {
@@ -701,6 +722,23 @@ abstract class SchemaPruningSuite
     }
   }
 
+  private def withEmployees(testThunk: => Unit): Unit = {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      makeDataSourceFile(employees, new File(path + "/employees"))
+
+      // Providing user specified schema. Inferred schema from different data sources might
+      // be different.
+      val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " +
+        "`employer` STRUCT<`name`: STRING, `address`: STRING>"
+      spark.read.format(dataSourceName).schema(schema).load(path + "/employees")
+        .createOrReplaceTempView("employees")
+
+      testThunk
+    }
+  }
+
   case class MixedCaseColumn(a: String, B: Int)
   case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)

From 4a251cb08a4a0a37ab2468fbbb0de26a5909c3c3 Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Fri, 22 Apr 2022 16:30:29 -0700
Subject: [PATCH 2/3] address comments

---
 .../sql/catalyst/expressions/ProjectionOverSchema.scala      | 5 ++---
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  | 1 +
 .../org/apache/spark/sql/catalyst/optimizer/objects.scala    | 2 +-
 .../spark/sql/execution/datasources/SchemaPruning.scala      | 2 +-
 .../execution/datasources/v2/V2ScanRelationPushDown.scala    | 2 +-
 .../spark/sql/execution/datasources/SchemaPruningSuite.scala | 5 +++++
 6 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 0b1765d5eea13..927f44837c097 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -29,14 +29,13 @@ import org.apache.spark.sql.types._
  * @param output output attributes of the data source relation. They are used to filter out
  *               attributes in the schema that do not belong to the current relation.
  */
-case class ProjectionOverSchema(schema: StructType, output: Option[AttributeSet] = None) {
-  private val fieldNames = schema.fieldNames.toSet
+case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
 
   def unapply(expr: Expression): Option[Expression] = getProjection(expr)
 
   private def getProjection(expr: Expression): Option[Expression] = expr match {
-    case a: AttributeReference if fieldNames.contains(a.name) && output.forall(_.contains(a)) =>
+    case a: AttributeReference if output.contains(a) =>
       Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
     case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
       getProjection(child).map {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index bb788336c6d77..2dc775e8e6671 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -60,6 +60,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val excludedOnceBatches: Set[String] =
     Set(
       "PartitionPruning",
+      "RewriteSubquery",
       "Extract Python UDFs")
 
   protected def fixedPoint =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index dee7ff5bac6ca..3387bb2007703 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -229,7 +229,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
     }
 
     // Builds new projection.
-    val projectionOverSchema = ProjectionOverSchema(prunedSchema, Some(AttributeSet(s.output)))
+    val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output))
     val newProjects = p.projectList.map(_.transformDown {
       case projectionOverSchema(expr) => expr
     }).map { case expr: NamedExpression => expr }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index 699f8a15f489c..26d5d92fecb3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -92,7 +92,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
         countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) {
         val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
         val projectionOverSchema = ProjectionOverSchema(
-          prunedDataSchema.merge(prunedMetadataSchema), Some(AttributeSet(relation.output)))
+          prunedDataSchema.merge(prunedMetadataSchema), AttributeSet(relation.output))
         Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
           prunedRelation, projectionOverSchema))
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index c9154aabf8766..b7e0531989f42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -320,7 +320,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
     val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
 
     val projectionOverSchema =
-      ProjectionOverSchema(output.toStructType, Some(AttributeSet(output)))
+      ProjectionOverSchema(output.toStructType, AttributeSet(output))
     val projectionFunc = (expr: Expression) => expr transformDown {
       case projectionOverSchema(newExpr) => newExpr
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 8cdaaae2feebb..3c715ca602a43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -637,6 +637,11 @@ abstract class SchemaPruningSuite
             |where not exists (select null from employees e where e.name.first = c.name.first
             |  and e.employer.name = c.employer.company.name)
             |""".stripMargin)
+        checkScan(query,
+          "struct<name:struct<first:string>," +
+            "employer:struct<company:struct<name:string>>>",
+          "struct<name:struct<first:string>," +
+            "employer:struct<name:string>>")
         checkAnswer(query, Row(3))
       }
     }

From 0f7eaa6a8963de162a1207d390c10f1cf2604c86 Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Tue, 26 Apr 2022 11:58:17 -0700
Subject: [PATCH 3/3] fix tests

---
 .../spark/sql/catalyst/expressions/ProjectionOverSchema.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 927f44837c097..69d30dd5048da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -30,12 +30,13 @@ import org.apache.spark.sql.types._
  *               attributes in the schema that do not belong to the current relation.
  */
 case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
+  private val fieldNames = schema.fieldNames.toSet
 
   def unapply(expr: Expression): Option[Expression] = getProjection(expr)
 
   private def getProjection(expr: Expression): Option[Expression] = expr match {
-    case a: AttributeReference if output.contains(a) =>
+    case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) =>
       Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
     case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
       getProjection(child).map {
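
A note on the fix (an editorial sketch, not part of the series): before these patches,
ProjectionOverSchema matched attributes by field name alone. When schema pruning visited a
plan containing a correlated subquery, an outer attribute that shares a column name with the
relation being pruned (contacts and employees above both expose `name` and `employer`) could
be rewritten to the pruned data type of the wrong relation. PATCH 2/3 therefore requires the
attribute to be a member of the relation's own `output`, and PATCH 3/3 restores the
`fieldNames` guard so that an attribute which is in `output` but absent from the pruned
`schema` never reaches `schema(a.name)`, which would throw.

The Scala sketch below illustrates the resulting matching behaviour. It assumes a Spark build
that includes PATCH 3/3 on the classpath (spark-catalyst); the object name and the attributes
are illustrative only, not taken from the patch.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, ProjectionOverSchema}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object ProjectionOverSchemaSketch {
  def main(args: Array[String]): Unit = {
    // Unpruned type of the `name` column, shared by two different relations.
    val fullName = StructType(Seq(
      StructField("first", StringType),
      StructField("middle", StringType),
      StructField("last", StringType)))

    // Pruned schema of the relation being rewritten: only `name.first` survives.
    val prunedSchema = StructType(Seq(
      StructField("name", StructType(Seq(StructField("first", StringType))))))

    // Same column name, different expression IDs: `inner` belongs to the relation
    // being pruned, `outer` to another relation (e.g. the outer query).
    val inner = AttributeReference("name", fullName)()
    val outer = AttributeReference("name", fullName)()

    val projection = ProjectionOverSchema(prunedSchema, AttributeSet(inner :: Nil))

    // The relation's own attribute is rewritten to the pruned data type...
    assert(projection.unapply(inner).exists(_.dataType == prunedSchema("name").dataType))
    // ...while the same-named foreign attribute is left alone. Matching by field
    // name only, as before the fix, would have rewritten both.
    assert(projection.unapply(outer).isEmpty)
  }
}

Passing the output set in explicitly, rather than re-deriving it inside the extractor, keeps
the match cheap and lets each caller (SchemaPruning, V2ScanRelationPushDown,
ObjectSerializerPruning) scope the rewrite to exactly the relation it is pruning.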