diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1d2e48301ea98..45cf82be51ed5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -371,6 +371,7 @@ class Analyzer(
       ResolveRandomSeed ::
       ResolveBinaryArithmetic ::
       ResolveUnion ::
+      ResolveWithFields ::
       TypeCoercion.typeCoercionRules(conf) ++
       extendedResolutionRules : _*),
     Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithFields.scala
new file mode 100644
index 0000000000000..710dca15ae9ae
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithFields.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.WithFields
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Replaces each [[UnresolvedWithFields]] whose children are resolved with the expression built
+ * by `WithFields(col, fieldName, expr)`.
+ *
+ * The rewrite waits for resolved children because building the replacement reads the data type
+ * of `col`, which an unresolved expression cannot provide.
+ */
+object ResolveWithFields extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case q: LogicalPlan if q.childrenResolved =>
+      // Bottom-up: in a chain such as `c.withField(..).withField(..)` the inner placeholder is
+      // rewritten first, so the outer one can be rewritten in the same pass once its children
+      // are resolved.
+      q.transformExpressionsUp {
+        case u: UnresolvedWithFields if u.childrenResolved =>
+          WithFields(u.col, u.fieldName, u.expr)
+      }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 62000ac0efbb3..332fb0aaef838 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -554,3 +554,24 @@ case class UnresolvedHaving(
   override lazy val resolved: Boolean = false
   override def output: Seq[Attribute] = child.output
 }
+
+/**
+ * Placeholder for adding or replacing the field `fieldName` of `col` with the value `expr`, used
+ * while `col` is not resolved yet. It is never resolved itself: `dataType`, `foldable` and
+ * `nullable` throw an `UnresolvedException`, and the analyzer rule `ResolveWithFields` replaces
+ * it once its children are resolved.
+ *
+ * @param col expression whose field is added or replaced
+ * @param fieldName name of the field to add or replace
+ * @param expr expression providing the new value of the field
+ */
+case class UnresolvedWithFields(
+    col: Expression,
+    fieldName: String,
+    expr: Expression) extends Unevaluable with NonSQLExpression {
+  override def children: Seq[Expression] = Seq(col, expr)
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override lazy val resolved: Boolean = false
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 563ce7133a3dc..1265e8a4b85e6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -595,3 +597,65 @@ case class WithFields(
     }
   }
 }
+
+object WithFields {
+  /**
+   * Builds the expression that adds the field `fieldName` to `col`, or replaces it when `col`
+   * already has a field with that name.
+   *
+   * @param col expression to update; its data type must be a struct or an array
+   * @param fieldName name of the field; unless empty, it is parsed as a multi-part identifier,
+   *                  so `a.b` addresses field `b` inside the struct field `a`
+   * @param expr expression providing the value of the field
+   * @return the expression computing the updated `col`
+   * @throws AnalysisException if `col`, or a value reached while descending into it, is neither
+   *                           a struct nor an array
+   */
+  def apply(
+      col: Expression,
+      fieldName: String,
+      expr: Expression): Expression = {
+    val nameParts = if (fieldName.isEmpty) {
+      fieldName :: Nil
+    } else {
+      CatalystSqlParser.parseMultipartIdentifier(fieldName)
+    }
+    withFieldHelper(col, nameParts, expr)
+  }
+
+  /**
+   * Recursively builds the expression that adds/replaces the field addressed by
+   * `namePartsRemaining` in `col`, dispatching on the data type of `col`:
+   *
+   *  - array: the whole path is applied to each element through an `ArrayTransform`;
+   *  - struct, one name left: the field is added/replaced directly;
+   *  - struct, several names left: the field named by the first part is extracted, the rest of
+   *    the path is applied to it, and the result replaces that field;
+   *  - anything else: an `AnalysisException` is thrown.
+   */
+  private def withFieldHelper(
+      col: Expression,
+      namePartsRemaining: Seq[String],
+      value: Expression): Expression = {
+    // Read before dispatching so that an empty path fails right away, whatever the type of `col`.
+    val name = namePartsRemaining.head
+    col.dataType match {
+      case ArrayType(et, containsNull) =>
+        // An array consumes no name: it only adds one level of element-wise mapping.
+        val lv = NamedLambdaVariable("arg", et, containsNull)
+        val function = withFieldHelper(lv, namePartsRemaining, value)
+        ArrayTransform(col, LambdaFunction(function, Seq(lv)))
+
+      case _: StructType if namePartsRemaining.length == 1 =>
+        WithFields(col, name :: Nil, value :: Nil)
+
+      case _: StructType =>
+        val newCol = ExtractValue(col, Literal(name), SQLConf.get.resolver)
+        val newValue = withFieldHelper(newCol, namePartsRemaining.tail, value)
+        WithFields(col, name :: Nil, newValue :: Nil)
+
+      case dt =>
+        throw new AnalysisException(s"WithFields's argument does not support ${dt.catalogString}")
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index da542c67d9c51..07a13848b8854 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -909,31 +909,10 @@ class Column(val expr: Expression) extends Logging {
     require(fieldName != null, "fieldName cannot be null")
     require(col != null, "col cannot be null")
 
-    val nameParts = if (fieldName.isEmpty) {
-      fieldName :: Nil
+    if (expr.resolved) {
+      WithFields(expr, fieldName, col.expr)
     } else {
-      CatalystSqlParser.parseMultipartIdentifier(fieldName)
-    }
-    
withFieldHelper(expr, nameParts, Nil, col.expr)
-  }
-
-  private def withFieldHelper(
-      struct: Expression,
-      namePartsRemaining: Seq[String],
-      namePartsDone: Seq[String],
-      value: Expression) : WithFields = {
-    val name = namePartsRemaining.head
-    if (namePartsRemaining.length == 1) {
-      WithFields(struct, name :: Nil, value :: Nil)
-    } else {
-      val newNamesRemaining = namePartsRemaining.tail
-      val newNamesDone = namePartsDone :+ name
-      val newValue = withFieldHelper(
-        struct = UnresolvedExtractValue(struct, Literal(name)),
-        namePartsRemaining = newNamesRemaining,
-        namePartsDone = newNamesDone,
-        value = value)
-      WithFields(struct, name :: Nil, newValue :: Nil)
+      UnresolvedWithFields(expr, fieldName, col.expr)  // rewritten by ResolveWithFields
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 24419968c0472..faacdcaa6c8f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -967,10 +967,49 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
         nullable = false))),
       nullable = false))))
 
+
+  private lazy val arrayType = ArrayType(  // array of struct(a, b, c); elements may be null
+    StructType(Seq(
+      StructField("a", IntegerType, nullable = false),
+      StructField("b", IntegerType, nullable = true),
+      StructField("c", IntegerType, nullable = false))),
+    containsNull = true)
+
+  private lazy val arrayLevel1: DataFrame = spark.createDataFrame(  // column a: array of struct
+    sparkContext.parallelize(Row(Array(Row(1, null, 3))) :: Nil),
+    StructType(Seq(StructField("a", arrayType, nullable = false))))
+
+  private lazy val nullArrayLevel1: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(Array(null)) :: Nil),  // the single array element is null
+    StructType(Seq(StructField("a", arrayType, nullable = true))))
+
+  private lazy val arrayLevel2: DataFrame = spark.createDataFrame(  // array one struct deep
+    sparkContext.parallelize(Row(Row(Array(Row(1, null, 3)))) :: Nil),
+    StructType(Seq(
+      StructField("a", StructType(Seq(
+        StructField("a", arrayType, nullable = false))),
+        nullable = false))))
+
+  private lazy val nullArrayLevel2: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(Row(Array(null))) :: Nil),  // null element, one struct deep
+    StructType(Seq(
+      StructField("a", StructType(Seq(
+        StructField("a", arrayType, nullable = false))),
+        nullable = false))))
+
+  private lazy val arrayLevel3: DataFrame = spark.createDataFrame(  // array two structs deep
+    sparkContext.parallelize(Row(Row(Row(Array(Row(1, null, 3))))) :: Nil),
+    StructType(Seq(
+      StructField("a", StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", arrayType, nullable = false))),
+          nullable = false))),
+        nullable = false))))
+
   test("withField should throw an exception if called on a non-StructType column") {
     intercept[AnalysisException] {
       testData.withColumn("key", $"key".withField("a", lit(2)))
-    }.getMessage should include("struct argument should be struct type, got: int")
+    }.getMessage should include("WithFields's argument does not support int")
   }
 
   test("withField should throw an exception if either fieldName or col argument are null") {
@@ -1000,7 +1039,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   test("withField should throw an exception if intermediate field is not a struct") {
     intercept[AnalysisException] {
       structLevel1.withColumn("a", 'a.withField("b.a", lit(2)))
-    }.getMessage should include("struct argument should be struct type, got: int")
+    }.getMessage should include("WithFields's argument does not support int")
   }
 
   test("withField should throw an exception if intermediate field reference is ambiguous") {
@@ -1537,4 +1576,303 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       Row(3) :: Row(null):: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
   }
+
+  test("withField should add field to struct of array") {
+    checkAnswerAndSchema(
+      arrayLevel1.withColumn("a", 'a.withField("d", lit(4))),
+      Row(Array(Row(1, null, 3, 4))) :: Nil,
+      StructType(Seq(
+        StructField("a", ArrayType(
+          StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = true),
+            StructField("c", IntegerType, nullable = false),
+            StructField("d", IntegerType, nullable = false))),
+          containsNull = true), nullable = false))))
+  }
+
+  test("withField should add multiple fields to struct of array") {
+    checkAnswerAndSchema(
+      arrayLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))),
+      Row(Array(Row(1, null, 3, 4, 5))) :: Nil,
+      StructType(Seq(
+        StructField("a", ArrayType(
+          StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = true),
+            StructField("c", IntegerType, nullable = false),
+            StructField("d", IntegerType, nullable = false),
+            StructField("e", IntegerType, nullable = false))),
+          containsNull = true), nullable = false))))
+  }
+
+  test("withField should add field to nested struct of array") {
+    Seq(  // the same update written as a nested path and as an explicit nested withField
+      arrayLevel2.withColumn("a", 'a.withField("a.d", lit(4))),
+      arrayLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4))))
+    ).foreach { df =>
+      checkAnswerAndSchema(
+        df,
+        Row(Row(Array(Row(1, null, 3, 4)))) :: Nil,
+        StructType(
+          Seq(StructField("a", StructType(Seq(
+            StructField("a", ArrayType(
+              StructType(Seq(
+                StructField("a", IntegerType, nullable = false),
+                StructField("b", IntegerType, nullable = true),
+                StructField("c", IntegerType, nullable = false),
+                StructField("d", IntegerType, nullable = false))),
+              containsNull = true),
+              nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("withField should add field to deeply nested struct of array") {
+    checkAnswerAndSchema(
+      arrayLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))),
+      Row(Row(Row(Array(Row(1, null, 3, 4))))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", ArrayType(
+              StructType(Seq(
+                StructField("a", IntegerType, nullable = false),
+                StructField("b", IntegerType, nullable = true),
+                StructField("c", IntegerType, nullable = false),
+                StructField("d", IntegerType, nullable = false))),
+              containsNull = true), nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should add field to null struct of array") {
+    checkAnswerAndSchema(
+      nullArrayLevel1.withColumn("a", $"a".withField("d", lit(4))),
+      Row(Array(null)) :: Nil,  // the null element stays null; only the schema gains "d"
+      StructType(Seq(
+        StructField("a", ArrayType(
+          StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = true),
+            StructField("c", IntegerType, nullable = false),
+            StructField("d", IntegerType, nullable = false))),
+          containsNull = true)))))
+  }
+
+  test("withField should add field to nested null struct of array") {
+    checkAnswerAndSchema(
+      nullArrayLevel2.withColumn("a", $"a".withField("a.d", lit(4))),
+      Row(Row(Array(null))) :: Nil,  // the null element stays null
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", ArrayType(
+            StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = true),
+              StructField("c", IntegerType, nullable = false),
+              StructField("d", IntegerType, nullable = false))),
+            containsNull = true),
+            nullable = false))),
+          nullable = false))))
+  }
+
+
+  test("withField should replace field in struct of array") {
+    checkAnswerAndSchema(
+      arrayLevel1.withColumn("a", 'a.withField("b", lit(2))),
+      Row(Array(Row(1, 2, 3))) :: Nil,
+      StructType(Seq(
+        StructField("a", ArrayType(
+          StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+          containsNull = true), nullable = false))))
+  }
+
+  test("withField should replace field in null struct of array") {
+    checkAnswerAndSchema(
+      nullArrayLevel1.withColumn("a", 'a.withField("b", lit("foo"))),
+      Row(Array(null)) :: Nil,  // the null element stays null; "b" changes type in the schema
+      StructType(Seq(
+        StructField("a", ArrayType(
+          StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", StringType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+          containsNull = true), nullable = true))))
+  }
+
+  test("withField should replace field in nested null struct of array") {
+    checkAnswerAndSchema(
+      nullArrayLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))),
+      Row(Row(Array(null))) :: Nil,  // the null element stays null
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", ArrayType(
+            StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", StringType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+            containsNull = true), nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should replace field with null value in struct of array") {
+    checkAnswerAndSchema(
+      arrayLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))),
+      Row(Array(Row(1, null, null))) :: Nil,
+      StructType(Seq(
+        StructField("a", ArrayType(
+          StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = true),
+            StructField("c", IntegerType, nullable = true))),  // "c" is expected to be nullable
+          containsNull = true), nullable = false))))
+  }
+
+  test("withField should replace multiple fields in struct of array") {
+    checkAnswerAndSchema(
+      arrayLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
+      Row(Array(Row(10, 20, 3))) :: Nil,
+      StructType(Seq(
+        StructField("a", ArrayType(
+          StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+          containsNull = true), nullable = false))))
+  }
+
+  test("withField should replace field in nested struct of array") {
+    Seq(  // the same update written as a nested path and as an explicit nested withField
+      arrayLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
+      arrayLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2))))
+    ).foreach { df =>
+      checkAnswerAndSchema(
+        df,
+        Row(Row(Array(Row(1, 2, 3)))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", ArrayType(
+              StructType(Seq(
+                StructField("a", IntegerType, nullable = false),
+                StructField("b", IntegerType, nullable = false),
+                StructField("c", IntegerType, nullable = false))),
+              containsNull = true), nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("withField should replace field in deeply nested struct of array") {
+    checkAnswerAndSchema(
+      arrayLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))),
+      Row(Row(Row(Array(Row(1, 2, 3))))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", ArrayType(StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+              containsNull = true), nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+
+  test("withField should replace all fields with given name in struct of array") {
+    val arrayLevel1 = spark.createDataFrame(  // shadows the suite-level fixture
+      sparkContext.parallelize(Row(Array(Row(1, 2, 3))) :: Nil),
+      StructType(Seq(
+        StructField("a", ArrayType(StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),  // second field named "b"
+          containsNull = false), nullable = false))))
+
+    checkAnswerAndSchema(
+      arrayLevel1.withColumn("a", 'a.withField("b", lit(100))),
+      Row(Array(Row(1, 100, 100))) :: Nil,  // both "b" fields are expected to be replaced
+      StructType(Seq(
+        StructField("a", ArrayType(StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          containsNull = false), nullable = false))))
+  }
+
+  private lazy val arrayStructArrayLevel1: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(Array(Row(Array(Row(1, null, 3)), null, 3))) :: Nil),
+    StructType(
+      Seq(StructField("a", ArrayType(
+        StructType(Seq(
+          StructField("a", arrayType, nullable = false),  // inner array of struct
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false))),
+        containsNull = false)))))
+
+  test("withField should add and replace field to struct of array of array") {
+    checkAnswerAndSchema(  // add "d" to the inner structs via the nested path "a.d"
+      arrayStructArrayLevel1.withColumn("a", $"a".withField("a.d", lit(2))),
+      Row(Seq(Row(Seq(Row(1, null, 3, 2)), null, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", ArrayType(
+          StructType(Seq(
+            StructField("a", ArrayType(
+              StructType(Seq(
+                StructField("a", IntegerType, nullable = false),
+                StructField("b", IntegerType, nullable = true),
+                StructField("c", IntegerType, nullable = false),
+                StructField("d", IntegerType, nullable = false))),
+              containsNull = true), nullable = false),
+            StructField("b", IntegerType, nullable = true),
+            StructField("c", IntegerType, nullable = false))),
+          containsNull = false)))))
+
+    checkAnswerAndSchema(  // add "d" on the extracted column a.a (an array of arrays)
+      arrayStructArrayLevel1.withColumn("a", $"a.a".withField("d", lit(2))),
+      Row(Seq(Seq(Row(1, null, 3, 2)))) :: Nil,
+      StructType(
+        Seq(StructField("a", ArrayType(
+          ArrayType(
+            StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = true),
+              StructField("c", IntegerType, nullable = false),
+              StructField("d", IntegerType, nullable = false))),
+            containsNull = true),
+          containsNull = false)))))
+
+    checkAnswerAndSchema(  // replace "b" in the inner structs via the nested path "a.b"
+      arrayStructArrayLevel1.withColumn("a", $"a".withField("a.b", lit(2))),
+      Row(Seq(Row(Seq(Row(1, 2, 3)), null, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", ArrayType(
+          StructType(Seq(
+            StructField("a", ArrayType(
+              StructType(Seq(
+                StructField("a", IntegerType, nullable = false),
+                StructField("b", IntegerType, nullable = false),
+                StructField("c", IntegerType, nullable = false))),
+              containsNull = true), nullable = false),
+            StructField("b", IntegerType, nullable = true),
+            StructField("c", IntegerType, nullable = false))),
+          containsNull = false)))))
+
+    checkAnswerAndSchema(  // replace "b" on the extracted column a.a (an array of arrays)
+      arrayStructArrayLevel1.withColumn("a", $"a.a".withField("b", lit(2))),
+      Row(Seq(Seq(Row(1, 2, 3)))) :: Nil,
+      StructType(
+        Seq(StructField("a", ArrayType(
+          ArrayType(
+            StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+            containsNull = true),
+          containsNull = false)))))
+
+  }
 }