diff --git a/native/core/src/execution/shuffle/row.rs b/native/core/src/execution/shuffle/row.rs index bb1401e263..c98cc54387 100644 --- a/native/core/src/execution/shuffle/row.rs +++ b/native/core/src/execution/shuffle/row.rs @@ -444,25 +444,18 @@ pub(crate) fn append_field( // Appending value into struct field builder of Arrow struct builder. let field_builder = struct_builder.field_builder::(idx).unwrap(); - if row.is_null_row() { - // The row is null. + let nested_row = if row.is_null_row() || row.is_null_at(idx) { + // The row is null, or the field in the row is null, i.e., a null nested row. + // Append a null value to the row builder. field_builder.append_null(); + SparkUnsafeRow::default() } else { - let is_null = row.is_null_at(idx); + field_builder.append(true); + row.get_struct(idx, fields.len()) + }; - let nested_row = if is_null { - // The field in the row is null, i.e., a null nested row. - // Append a null value to the row builder. - field_builder.append_null(); - SparkUnsafeRow::default() - } else { - field_builder.append(true); - row.get_struct(idx, fields.len()) - }; - - for (field_idx, field) in fields.into_iter().enumerate() { - append_field(field.data_type(), field_builder, &nested_row, field_idx)?; - } + for (field_idx, field) in fields.into_iter().enumerate() { + append_field(field.data_type(), field_builder, &nested_row, field_idx)?; } } DataType::Map(field, _) => { @@ -3302,3 +3295,45 @@ fn make_batch(arrays: Vec, row_count: usize) -> Result + val testData = "{}\n" + val path = Paths.get(dir.toString, "test.json") + Files.write(path, testData.getBytes) + + // Define the nested struct schema + val readSchema = StructType( + Array( + StructField( + "metaData", + StructType( + Array(StructField( + "format", + StructType(Array(StructField("provider", StringType, nullable = true))), + nullable = true))), + nullable = true))) + + // Read JSON with custom schema and repartition, this will repartition rows that contain + // null struct fields. + val df = spark.read.format("json").schema(readSchema).load(path.toString).repartition(2) + assert(df.count() == 1) + val row = df.collect()(0) + assert(row.getAs[org.apache.spark.sql.Row]("metaData") == null) + } + } + /** * Checks that `df` produces the same answer as Spark does, and has the `expectedNum` Comet * exchange operators.