From 42d6eda8c94bc629753fee5ccff61d41a11969db Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 9 Apr 2021 12:32:21 -0700 Subject: [PATCH] Fix nested pruning on array of struct. --- .../sql/execution/ProjectionOverSchema.scala | 9 +++- .../spark/sql/execution/SelectedField.scala | 26 +++++++----- .../parquet/ParquetSchemaPruningSuite.scala | 42 +++++++++++++++++++ 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala index 612a7b87b9832..de5ecb04511c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala @@ -40,9 +40,14 @@ private[execution] case class ProjectionOverSchema(schema: StructType) { case a: GetArrayStructFields => getProjection(a.child).map(p => (p, p.dataType)).map { case (projection, ArrayType(projSchema @ StructType(_), _)) => + // For case-sensitivity aware field resolution, we should take `ordinal` which + // points to correct struct field. 
+ val selectedField = a.child.dataType.asInstanceOf[ArrayType] + .elementType.asInstanceOf[StructType](a.ordinal) + val prunedField = projSchema(selectedField.name) GetArrayStructFields(projection, - projSchema(a.field.name), - projSchema.fieldIndex(a.field.name), + prunedField.copy(name = a.field.name), + projSchema.fieldIndex(selectedField.name), projSchema.size, a.containsNull) case (_, projSchema) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala index 0e7c593f9fb67..23e894b2c123c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala @@ -81,19 +81,25 @@ private[execution] object SelectedField { case GetArrayItem(child, _) => selectField(child, fieldOpt) // Handles case "expr0.field.subfield", where "expr0" and "expr0.field" are of array type. - case GetArrayStructFields(child: GetArrayStructFields, - field @ StructField(name, dataType, nullable, metadata), _, _, _) => - val childField = fieldOpt.map(field => StructField(name, - wrapStructType(dataType, field), - nullable, metadata)).orElse(Some(field)) + case GetArrayStructFields(child: GetArrayStructFields, _, ordinal, _, _) => + // For case-sensitivity aware field resolution, we should take `ordinal` which + // points to correct struct field. + val selectedField = child.dataType.asInstanceOf[ArrayType] + .elementType.asInstanceOf[StructType](ordinal) + val childField = fieldOpt.map(field => StructField(selectedField.name, + wrapStructType(selectedField.dataType, field), + selectedField.nullable, selectedField.metadata)).orElse(Some(selectedField)) selectField(child, childField) // Handles case "expr0.field", where "expr0" is of array type. 
- case GetArrayStructFields(child, - field @ StructField(name, dataType, nullable, metadata), _, _, _) => + case GetArrayStructFields(child, _, ordinal, _, _) => + // For case-sensitivity aware field resolution, we should take `ordinal` which + // points to correct struct field. + val selectedField = child.dataType.asInstanceOf[ArrayType] + .elementType.asInstanceOf[StructType](ordinal) val childField = - fieldOpt.map(field => StructField(name, - wrapStructType(dataType, field), - nullable, metadata)).orElse(Some(field)) + fieldOpt.map(field => StructField(selectedField.name, + wrapStructType(selectedField.dataType, field), + selectedField.nullable, selectedField.metadata)).orElse(Some(selectedField)) selectField(child, childField) // Handles case "expr0.field[key]", where "expr0" is of struct type and "expr0.field" is of // map type. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala index 966190e12c6ba..95bc7ca963155 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -416,4 +416,46 @@ class ParquetSchemaPruningSuite assert(scanSchema === expectedScanSchema) } } + + testSchemaPruning("SPARK-34963: extract case-insensitive struct field from array") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val query1 = spark.table("contacts") + .select("friends.First", "friends.MiDDle") + checkScan(query1, "struct<friends:array<struct<first:string,middle:string>>>") + checkAnswer(query1, + Row(Array.empty[String], Array.empty[String]) :: + Row(Array("Susan"), Array("Z.")) :: + Row(null, null) :: + Row(null, null) :: Nil) + + val query2 = spark.table("contacts") + .where("friends.First is not null") + .select("friends.First", "friends.MiDDle") + 
checkScan(query2, "struct<friends:array<struct<first:string,middle:string>>>") + checkAnswer(query2, + Row(Array.empty[String], Array.empty[String]) :: + Row(Array("Susan"), Array("Z.")) :: Nil) + } + } + + testSchemaPruning("SPARK-34963: extract case-insensitive struct field from struct") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val query1 = spark.table("contacts") + .select("Name.First", "NAME.MiDDle") + checkScan(query1, "struct<name:struct<first:string,middle:string>>") + checkAnswer(query1, + Row("Jane", "X.") :: + Row("Janet", null) :: + Row("Jim", null) :: + Row("John", "Y.") :: Nil) + + val query2 = spark.table("contacts") + .where("Name.MIDDLE is not null") + .select("Name.First", "NAME.MiDDle") + checkScan(query2, "struct<name:struct<first:string,middle:string>>") + checkAnswer(query2, + Row("Jane", "X.") :: + Row("John", "Y.") :: Nil) + } + } }