From f4db7e94c1a13c545dba9ee267a3e55946830010 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 5 Apr 2021 14:18:33 -0700 Subject: [PATCH 1/5] Fix nested pruning on array of struct. --- .../expressions/ProjectionOverSchema.scala | 7 +++++-- .../catalyst/expressions/SchemaPruning.scala | 21 +++++++++++++------ .../datasources/SchemaPruningSuite.scala | 13 ++++++++++++ 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 241c761624b76..3dad2bc7f5894 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -27,6 +28,7 @@ import org.apache.spark.sql.types._ */ case class ProjectionOverSchema(schema: StructType) { private val fieldNames = schema.fieldNames.toSet + private val resolver = SQLConf.get.resolver def unapply(expr: Expression): Option[Expression] = getProjection(expr) @@ -41,9 +43,10 @@ case class ProjectionOverSchema(schema: StructType) { case a: GetArrayStructFields => getProjection(a.child).map(p => (p, p.dataType)).map { case (projection, ArrayType(projSchema @ StructType(_), _)) => + val selectedField = projSchema.find(f => resolver(f.name, a.field.name)).get GetArrayStructFields(projection, - projSchema(a.field.name), - projSchema.fieldIndex(a.field.name), + selectedField, + projSchema.fieldIndex(selectedField.name), projSchema.size, a.containsNull) case (_, projSchema) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index 6213267c41c64..bbd00e2b9502d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ object SchemaPruning { @@ -48,7 +49,8 @@ object SchemaPruning { * right, recursively. That is, left is a "subschema" of right, ignoring order of * fields. */ - private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = + private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = { + val resolver = SQLConf.get.resolver (left, right) match { case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) => ArrayType( @@ -61,16 +63,23 @@ object SchemaPruning { sortLeftFieldsByRight(leftValueType, rightValueType), containsNull) case (leftStruct: StructType, rightStruct: StructType) => - val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains) - val sortedLeftFields = filteredRightFieldNames.map { fieldName => - val leftFieldType = leftStruct(fieldName).dataType - val rightFieldType = rightStruct(fieldName).dataType + val filteredRightFieldNames = rightStruct.fieldNames.filter { rightField => + leftStruct.fieldNames.exists(resolver(_, rightField)) + } + val matchedFields = filteredRightFieldNames.map { rightField => + (leftStruct.fieldNames.find(resolver(_, rightField)).get, rightField) + } + val sortedLeftFields = matchedFields.map { case (leftFieldName, rightFieldName) => + val leftFieldType = leftStruct(leftFieldName).dataType + val rightFieldType = rightStruct(rightFieldName).dataType val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType) - StructField(fieldName, sortedLeftFieldType, 
nullable = leftStruct(fieldName).nullable) + StructField(rightFieldName, sortedLeftFieldType, + nullable = leftStruct(leftFieldName).nullable) } StructType(sortedLeftFields) case _ => left } + } /** * Returns the set of fields from projection and filtering predicates that the query plan needs. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index c90732183cb7a..7684421f753d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -774,4 +774,17 @@ abstract class SchemaPruningSuite assert(scanSchema === expectedScanSchema) } } + + testSchemaPruning("extract case-insensitive struct field from array") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val query = spark.table("contacts") + .select("friends.First", "friends.MiDDle") + checkScan(query, "struct<friends:array<struct<first:string,middle:string>>>") + checkAnswer(query, + Row(Array.empty[String], Array.empty[String]) :: + Row(Array("Susan"), Array("Z.")) :: + Row(null, null) :: + Row(null, null) :: Nil) + } + } } From c33530448016aa3ff7e785b9613a23bb89c93059 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 6 Apr 2021 20:13:24 -0700 Subject: [PATCH 2/5] Use correctly resolved ordinal from GetArrayStructFields.
--- .../expressions/ProjectionOverSchema.scala | 7 ++++--- .../catalyst/expressions/SchemaPruning.scala | 18 +++++------------- .../catalyst/expressions/SelectedField.scala | 6 +++++- .../datasources/SchemaPruningSuite.scala | 15 ++++++++++++++- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 3dad2bc7f5894..2574bbf3e4e56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -28,7 +27,6 @@ import org.apache.spark.sql.types._ */ case class ProjectionOverSchema(schema: StructType) { private val fieldNames = schema.fieldNames.toSet - private val resolver = SQLConf.get.resolver def unapply(expr: Expression): Option[Expression] = getProjection(expr) @@ -43,7 +41,10 @@ case class ProjectionOverSchema(schema: StructType) { case a: GetArrayStructFields => getProjection(a.child).map(p => (p, p.dataType)).map { case (projection, ArrayType(projSchema @ StructType(_), _)) => - val selectedField = projSchema.find(f => resolver(f.name, a.field.name)).get + // For case-sensitivity aware field resolution, we should take `ordinal` which + // points to correct struct field. 
+ val selectedField = a.child.dataType.asInstanceOf[ArrayType] + .elementType.asInstanceOf[StructType](a.ordinal) GetArrayStructFields(projection, selectedField, projSchema.fieldIndex(selectedField.name), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index bbd00e2b9502d..175814c9e0faf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ object SchemaPruning { @@ -50,7 +49,6 @@ object SchemaPruning { * fields. */ private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = { - val resolver = SQLConf.get.resolver (left, right) match { case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) => ArrayType( @@ -63,18 +61,12 @@ object SchemaPruning { sortLeftFieldsByRight(leftValueType, rightValueType), containsNull) case (leftStruct: StructType, rightStruct: StructType) => - val filteredRightFieldNames = rightStruct.fieldNames.filter { rightField => - leftStruct.fieldNames.exists(resolver(_, rightField)) - } - val matchedFields = filteredRightFieldNames.map { rightField => - (leftStruct.fieldNames.find(resolver(_, rightField)).get, rightField) - } - val sortedLeftFields = matchedFields.map { case (leftFieldName, rightFieldName) => - val leftFieldType = leftStruct(leftFieldName).dataType - val rightFieldType = rightStruct(rightFieldName).dataType + val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains) + val sortedLeftFields = filteredRightFieldNames.map { fieldName => + val leftFieldType = leftStruct(fieldName).dataType + val rightFieldType = 
rightStruct(fieldName).dataType val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType) - StructField(rightFieldName, sortedLeftFieldType, - nullable = leftStruct(leftFieldName).nullable) + StructField(fieldName, sortedLeftFieldType, nullable = leftStruct(fieldName).nullable) } StructType(sortedLeftFields) case _ => left diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala index a5a42e540151d..4314aad9b46f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala @@ -75,7 +75,11 @@ object SelectedField { val field = c.childSchema(c.ordinal) val newField = field.copy(dataType = dataTypeOpt.getOrElse(field.dataType)) selectField(c.child, Option(struct(newField))) - case GetArrayStructFields(child, field, _, _, containsNull) => + case GetArrayStructFields(child, _, ordinal, _, containsNull) => + // For case-sensitivity aware field resolution, we should take `ordinal` which + // points to correct struct field. + val field = child.dataType.asInstanceOf[ArrayType] + .elementType.asInstanceOf[StructType](ordinal) val newFieldDataType = dataTypeOpt match { case None => // GetArrayStructFields is the top level extractor. 
This means its result is diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 7684421f753d4..6a52fc36e04a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -775,7 +775,7 @@ abstract class SchemaPruningSuite } } - testSchemaPruning("extract case-insensitive struct field from array") { + testSchemaPruning("SPARK-34963: extract case-insensitive struct field from array") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { val query = spark.table("contacts") .select("friends.First", "friends.MiDDle") @@ -787,4 +787,17 @@ abstract class SchemaPruningSuite Row(null, null) :: Nil) } } + + testSchemaPruning("SPARK-34963: extract case-insensitive struct field from struct") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val query = spark.table("contacts") + .select("Name.First", "NAME.MiDDle") + checkScan(query, "struct<name:struct<first:string,middle:string>>") + checkAnswer(query, + Row("Jane", "X.") :: + Row("Janet", null) :: + Row("Jim", null) :: + Row("John", "Y.") :: Nil) + } + } } From ea17366a9d3cddebdfa671d38ed1e46e755de668 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 8 Apr 2021 10:15:48 -0700 Subject: [PATCH 3/5] Fix.
--- .../spark/sql/catalyst/expressions/ProjectionOverSchema.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 2574bbf3e4e56..03b5517f6df05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -45,8 +45,9 @@ case class ProjectionOverSchema(schema: StructType) { // points to correct struct field. val selectedField = a.child.dataType.asInstanceOf[ArrayType] .elementType.asInstanceOf[StructType](a.ordinal) + val prunedField = projSchema(selectedField.name) GetArrayStructFields(projection, - selectedField, + prunedField.copy(name = a.field.name), projSchema.fieldIndex(selectedField.name), projSchema.size, a.containsNull) From 97a4784cabc49e943a549eb7bbe48d955831db01 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 8 Apr 2021 10:34:14 -0700 Subject: [PATCH 4/5] Add more test cases. 
--- .../datasources/SchemaPruningSuite.scala | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 6a52fc36e04a7..765d2fc584a7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -777,27 +777,43 @@ abstract class SchemaPruningSuite testSchemaPruning("SPARK-34963: extract case-insensitive struct field from array") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - val query = spark.table("contacts") + val query1 = spark.table("contacts") .select("friends.First", "friends.MiDDle") - checkScan(query, "struct<friends:array<struct<first:string,middle:string>>>") - checkAnswer(query, + checkScan(query1, "struct<friends:array<struct<first:string,middle:string>>>") + checkAnswer(query1, Row(Array.empty[String], Array.empty[String]) :: Row(Array("Susan"), Array("Z.")) :: Row(null, null) :: Row(null, null) :: Nil) + + val query2 = spark.table("contacts") + .where("friends.First is not null") + .select("friends.First", "friends.MiDDle") + checkScan(query2, "struct<friends:array<struct<first:string,middle:string>>>") + checkAnswer(query2, + Row(Array.empty[String], Array.empty[String]) :: + Row(Array("Susan"), Array("Z.")) :: Nil) } } testSchemaPruning("SPARK-34963: extract case-insensitive struct field from struct") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - val query = spark.table("contacts") + val query1 = spark.table("contacts") .select("Name.First", "NAME.MiDDle") - checkScan(query, "struct<name:struct<first:string,middle:string>>") - checkAnswer(query, + checkScan(query1, "struct<name:struct<first:string,middle:string>>") + checkAnswer(query1, Row("Jane", "X.") :: Row("Janet", null) :: Row("Jim", null) :: Row("John", "Y.") :: Nil) + + val query2 = spark.table("contacts") + .where("Name.MIDDLE is not null") + .select("Name.First", "NAME.MiDDle") + checkScan(query2, "struct<name:struct<first:string,middle:string>>") + checkAnswer(query2, +
Row("Jane", "X.") :: + Row("John", "Y.") :: Nil) } } } From 9005055df2f971a7d6ee909baaa0366c5e3b6683 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 8 Apr 2021 10:42:50 -0700 Subject: [PATCH 5/5] Remove unnecessary change. --- .../apache/spark/sql/catalyst/expressions/SchemaPruning.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index 175814c9e0faf..6213267c41c64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -48,7 +48,7 @@ object SchemaPruning { * right, recursively. That is, left is a "subschema" of right, ignoring order of * fields. */ - private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = { + private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = (left, right) match { case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) => ArrayType( @@ -71,7 +71,6 @@ object SchemaPruning { StructType(sortedLeftFields) case _ => left } - } /** * Returns the set of fields from projection and filtering predicates that the query plan needs.