diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc b/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc index cc36c510363..d91bf032e58 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc @@ -153,27 +153,32 @@ void AddListCast(CastFunction* func) { struct CastStruct { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const CastOptions& options = CastState::Get(ctx); - const StructType& in_type = checked_cast(*batch[0].type()); - const StructType& out_type = checked_cast(*out->type()); - const auto in_field_count = in_type.num_fields(); - - if (in_field_count != out_type.num_fields()) { - return Status::TypeError("struct field sizes do not match: ", in_type.ToString(), - " ", out_type.ToString()); - } - - for (int i = 0; i < in_field_count; ++i) { - const auto in_field = in_type.field(i); - const auto out_field = out_type.field(i); - if (in_field->name() != out_field->name()) { - return Status::TypeError("struct field names do not match: ", in_type.ToString(), - " ", out_type.ToString()); + const auto& in_type = checked_cast(*batch[0].type()); + const auto& out_type = checked_cast(*out->type()); + const int in_field_count = in_type.num_fields(); + const int out_field_count = out_type.num_fields(); + + std::vector fields_to_select(out_field_count, -1); + + int out_field_index = 0; + for (int in_field_index = 0; + in_field_index < in_field_count && out_field_index < out_field_count; + ++in_field_index) { + const auto& in_field = in_type.field(in_field_index); + const auto& out_field = out_type.field(out_field_index); + if (in_field->name() == out_field->name()) { + if (in_field->nullable() && !out_field->nullable()) { + return Status::TypeError("cannot cast nullable field to non-nullable field: ", + in_type.ToString(), " ", out_type.ToString()); + } + fields_to_select[out_field_index++] = in_field_index; } + } - if (in_field->nullable() && !out_field->nullable()) { - return Status::TypeError("cannot cast nullable struct to non-nullable struct: ", - in_type.ToString(), " ", out_type.ToString()); - } + if (out_field_index < out_field_count) { + return Status::TypeError( + "struct fields don't match or are in the wrong order: Input fields: ", + in_type.ToString(), " output fields: ", out_type.ToString()); } if (out->kind() == Datum::SCALAR) { @@ -182,9 +187,10 @@ struct CastStruct { DCHECK(!out_scalar->is_valid); if (in_scalar.is_valid) { - for (int i = 0; i < in_field_count; i++) { - auto values = in_scalar.value[i]; - auto target_type = out->type()->field(i)->type(); + out_field_index = 0; + for (int field_index : fields_to_select) { + const auto& values = in_scalar.value[field_index]; + const auto& target_type = out->type()->field(out_field_index++)->type(); ARROW_ASSIGN_OR_RAISE(Datum cast_values, Cast(values, target_type, options, ctx->exec_context())); DCHECK_EQ(Datum::SCALAR, cast_values.kind()); @@ -204,9 +210,11 @@ struct CastStruct { in_array.offset, in_array.length)); } - for (int i = 0; i < in_field_count; ++i) { - auto values = in_array.child_data[i]->Slice(in_array.offset, in_array.length); - auto target_type = out->type()->field(i)->type(); + out_field_index = 0; + for (int field_index : fields_to_select) { + const auto& values = + in_array.child_data[field_index]->Slice(in_array.offset, in_array.length); + const auto& target_type = out->type()->field(out_field_index++)->type(); ARROW_ASSIGN_OR_RAISE(Datum cast_values, Cast(values, target_type, options, ctx->exec_context())); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index f13b05ccd07..862e54d94bd 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -2245,8 +2245,223 @@ static void CheckStructToStruct( } } -TEST(Cast, StructToSameSizedAndNamedStruct) { - CheckStructToStruct({int32(), float32(), int64()}); +static void CheckStructToStructSubset( + const std::vector>& value_types) { + for (const auto& src_value_type : value_types) { + ARROW_SCOPED_TRACE("From type: ", src_value_type->ToString()); + for (const auto& dest_value_type : value_types) { + ARROW_SCOPED_TRACE("To type: ", dest_value_type->ToString()); + + std::vector field_names = {"a", "b", "c", "d", "e"}; + + std::shared_ptr a1, b1, c1, d1, e1; + a1 = ArrayFromJSON(src_value_type, "[1, 2, 5]"); + b1 = ArrayFromJSON(src_value_type, "[3, 4, 7]"); + c1 = ArrayFromJSON(src_value_type, "[9, 11, 44]"); + d1 = ArrayFromJSON(src_value_type, "[6, 51, 49]"); + e1 = ArrayFromJSON(src_value_type, "[19, 17, 74]"); + + std::shared_ptr a2, b2, c2, d2, e2; + a2 = ArrayFromJSON(dest_value_type, "[1, 2, 5]"); + b2 = ArrayFromJSON(dest_value_type, "[3, 4, 7]"); + c2 = ArrayFromJSON(dest_value_type, "[9, 11, 44]"); + d2 = ArrayFromJSON(dest_value_type, "[6, 51, 49]"); + e2 = ArrayFromJSON(dest_value_type, "[19, 17, 74]"); + + ASSERT_OK_AND_ASSIGN(auto src, + StructArray::Make({a1, b1, c1, d1, e1}, field_names)); + ASSERT_OK_AND_ASSIGN(auto dest1, + StructArray::Make({a2}, std::vector{"a"})); + CheckCast(src, dest1); + + ASSERT_OK_AND_ASSIGN( + auto dest2, StructArray::Make({b2, c2}, std::vector{"b", "c"})); + CheckCast(src, dest2); + + ASSERT_OK_AND_ASSIGN( + auto dest3, + StructArray::Make({c2, d2, e2}, std::vector{"c", "d", "e"})); + CheckCast(src, dest3); + + ASSERT_OK_AND_ASSIGN( + auto dest4, StructArray::Make({a2, b2, c2, e2}, + std::vector{"a", "b", "c", "e"})); + CheckCast(src, dest4); + + ASSERT_OK_AND_ASSIGN( + auto dest5, StructArray::Make({a2, b2, c2, d2, e2}, {"a", "b", "c", "d", "e"})); + CheckCast(src, dest5); + + // field does not exist + const auto dest6 = arrow::struct_({std::make_shared("a", int8()), + std::make_shared("d", int16()), + std::make_shared("f", int64())}); + const auto options6 = CastOptions::Safe(dest6); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("struct fields don't match or are in the wrong order"), + Cast(src, options6)); + + // fields in wrong order + const auto dest7 = arrow::struct_({std::make_shared("a", int8()), + std::make_shared("c", int16()), + std::make_shared("b", int64())}); + const auto options7 = CastOptions::Safe(dest7); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("struct fields don't match or are in the wrong order"), + Cast(src, options7)); + + // duplicate missing field names + const auto dest8 = arrow::struct_( + {std::make_shared("a", int8()), std::make_shared("c", int16()), + std::make_shared("d", int32()), std::make_shared("a", int64())}); + const auto options8 = CastOptions::Safe(dest8); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("struct fields don't match or are in the wrong order"), + Cast(src, options8)); + + // duplicate present field names + ASSERT_OK_AND_ASSIGN( + auto src_duplicate_field_names, + StructArray::Make({a1, b1, c1}, std::vector{"a", "a", "a"})); + + ASSERT_OK_AND_ASSIGN(auto dest1_duplicate_field_names, + StructArray::Make({a2}, std::vector{"a"})); + CheckCast(src_duplicate_field_names, dest1_duplicate_field_names); + + ASSERT_OK_AND_ASSIGN( + auto dest2_duplicate_field_names, + StructArray::Make({a2, b2}, std::vector{"a", "a"})); + CheckCast(src_duplicate_field_names, dest2_duplicate_field_names); + + ASSERT_OK_AND_ASSIGN( + auto dest3_duplicate_field_names, + StructArray::Make({a2, b2, c2}, std::vector{"a", "a", "a"})); + CheckCast(src_duplicate_field_names, dest3_duplicate_field_names); + } + } +} + +static void CheckStructToStructSubsetWithNulls( + const std::vector>& value_types) { + for (const auto& src_value_type : value_types) { + ARROW_SCOPED_TRACE("From type: ", src_value_type->ToString()); + for (const auto& dest_value_type : value_types) { + ARROW_SCOPED_TRACE("To type: ", dest_value_type->ToString()); + + std::vector field_names = {"a", "b", "c", "d", "e"}; + + std::shared_ptr a1, b1, c1, d1, e1; + a1 = ArrayFromJSON(src_value_type, "[1, 2, 5]"); + b1 = ArrayFromJSON(src_value_type, "[3, null, 7]"); + c1 = ArrayFromJSON(src_value_type, "[9, 11, 44]"); + d1 = ArrayFromJSON(src_value_type, "[6, 51, null]"); + e1 = ArrayFromJSON(src_value_type, "[null, 17, 74]"); + + std::shared_ptr a2, b2, c2, d2, e2; + a2 = ArrayFromJSON(dest_value_type, "[1, 2, 5]"); + b2 = ArrayFromJSON(dest_value_type, "[3, null, 7]"); + c2 = ArrayFromJSON(dest_value_type, "[9, 11, 44]"); + d2 = ArrayFromJSON(dest_value_type, "[6, 51, null]"); + e2 = ArrayFromJSON(dest_value_type, "[null, 17, 74]"); + + std::shared_ptr null_bitmap; + BitmapFromVector({0, 1, 0}, &null_bitmap); + + ASSERT_OK_AND_ASSIGN(auto src_null, StructArray::Make({a1, b1, c1, d1, e1}, + field_names, null_bitmap)); + ASSERT_OK_AND_ASSIGN( + auto dest1_null, + StructArray::Make({a2}, std::vector{"a"}, null_bitmap)); + CheckCast(src_null, dest1_null); + + ASSERT_OK_AND_ASSIGN( + auto dest2_null, + StructArray::Make({b2, c2}, std::vector{"b", "c"}, null_bitmap)); + CheckCast(src_null, dest2_null); + + ASSERT_OK_AND_ASSIGN( + auto dest3_null, + StructArray::Make({a2, d2, e2}, std::vector{"a", "d", "e"}, + null_bitmap)); + CheckCast(src_null, dest3_null); + + ASSERT_OK_AND_ASSIGN( + auto dest4_null, + StructArray::Make({a2, b2, c2, e2}, + std::vector{"a", "b", "c", "e"}, null_bitmap)); + CheckCast(src_null, dest4_null); + + ASSERT_OK_AND_ASSIGN( + auto dest5_null, + StructArray::Make({a2, b2, c2, d2, e2}, + std::vector{"a", "b", "c", "d", "e"}, + null_bitmap)); + CheckCast(src_null, dest5_null); + + // field does not exist + const auto dest6_null = arrow::struct_({std::make_shared("a", int8()), + std::make_shared("d", int16()), + std::make_shared("f", int64())}); + const auto options6_null = CastOptions::Safe(dest6_null); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("struct fields don't match or are in the wrong order"), + Cast(src_null, options6_null)); + + // fields in wrong order + const auto dest7_null = arrow::struct_({std::make_shared("a", int8()), + std::make_shared("c", int16()), + std::make_shared("b", int64())}); + const auto options7_null = CastOptions::Safe(dest7_null); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("struct fields don't match or are in the wrong order"), + Cast(src_null, options7_null)); + + // duplicate missing field names + const auto dest8_null = arrow::struct_( + {std::make_shared("a", int8()), std::make_shared("c", int16()), + std::make_shared("d", int32()), std::make_shared("a", int64())}); + const auto options8_null = CastOptions::Safe(dest8_null); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("struct fields don't match or are in the wrong order"), + Cast(src_null, options8_null)); + + // duplicate present field values + ASSERT_OK_AND_ASSIGN( + auto src_duplicate_field_names_null, + StructArray::Make({a1, b1, c1}, std::vector{"a", "a", "a"}, + null_bitmap)); + + ASSERT_OK_AND_ASSIGN( + auto dest1_duplicate_field_names_null, + StructArray::Make({a2}, std::vector{"a"}, null_bitmap)); + CheckCast(src_duplicate_field_names_null, dest1_duplicate_field_names_null); + + ASSERT_OK_AND_ASSIGN( + auto dest2_duplicate_field_names_null, + StructArray::Make({a2, b2}, std::vector{"a", "a"}, null_bitmap)); + CheckCast(src_duplicate_field_names_null, dest2_duplicate_field_names_null); + + ASSERT_OK_AND_ASSIGN( + auto dest3_duplicate_field_names_null, + StructArray::Make({a2, b2, c2}, std::vector{"a", "a", "a"}, + null_bitmap)); + CheckCast(src_duplicate_field_names_null, dest3_duplicate_field_names_null); + } + } +} + +TEST(Cast, StructToSameSizedAndNamedStruct) { CheckStructToStruct(NumericTypes()); } + +TEST(Cast, StructToStructSubset) { CheckStructToStructSubset(NumericTypes()); } + +TEST(Cast, StructToStructSubsetWithNulls) { + CheckStructToStructSubsetWithNulls(NumericTypes()); } TEST(Cast, StructToSameSizedButDifferentNamedStruct) { @@ -2262,12 +2477,11 @@ TEST(Cast, StructToSameSizedButDifferentNamedStruct) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("Type error: struct field names do not match: struct struct"), + ::testing::HasSubstr("struct fields don't match or are in the wrong order"), Cast(src, options)); } -TEST(Cast, StructToDifferentSizeStruct) { +TEST(Cast, StructToBiggerStruct) { std::vector field_names = {"a", "b"}; std::shared_ptr a, b; a = ArrayFromJSON(int8(), "[1, 2]"); @@ -2281,52 +2495,100 @@ TEST(Cast, StructToDifferentSizeStruct) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("Type error: struct field sizes do not match: struct struct"), + ::testing::HasSubstr("struct fields don't match or are in the wrong order"), Cast(src, options)); } -TEST(Cast, StructToSameSizedButDifferentNullabilityStruct) { - // OK to go from not-nullable to nullable... - std::vector> fields1 = { - std::make_shared("a", int8(), false), - std::make_shared("b", int8(), false)}; - std::shared_ptr a1, b1; - a1 = ArrayFromJSON(int8(), "[1, 2]"); - b1 = ArrayFromJSON(int8(), "[3, 4]"); - ASSERT_OK_AND_ASSIGN(auto src1, StructArray::Make({a1, b1}, fields1)); - - std::vector> fields2 = { - std::make_shared("a", int8(), true), - std::make_shared("b", int8(), true)}; - std::shared_ptr a2, b2; - a2 = ArrayFromJSON(int8(), "[1, 2]"); - b2 = ArrayFromJSON(int8(), "[3, 4]"); - ASSERT_OK_AND_ASSIGN(auto dest1, StructArray::Make({a2, b2}, fields2)); - - CheckCast(src1, dest1); - - // But not the other way around - std::vector> fields3 = { - std::make_shared("a", int8(), true), - std::make_shared("b", int8(), true)}; - std::shared_ptr a3, b3; - a3 = ArrayFromJSON(int8(), "[1, null]"); - b3 = ArrayFromJSON(int8(), "[3, 4]"); - ASSERT_OK_AND_ASSIGN(auto src2, StructArray::Make({a3, b3}, fields3)); - - std::vector> fields4 = { - std::make_shared("a", int8(), false), - std::make_shared("b", int8(), false)}; - const auto dest2 = arrow::struct_(fields4); - const auto options = CastOptions::Safe(dest2); - - EXPECT_RAISES_WITH_MESSAGE_THAT( - TypeError, - ::testing::HasSubstr( - "Type error: cannot cast nullable struct to non-nullable " - "struct: struct struct"), - Cast(src2, options)); +TEST(Cast, StructToDifferentNullabilityStruct) { + { + // OK to go from non-nullable to nullable... + std::vector> fields_src_non_nullable = { + std::make_shared("a", int8(), false), + std::make_shared("b", int8(), false), + std::make_shared("c", int8(), false)}; + std::shared_ptr a_src_non_nullable, b_src_non_nullable, c_src_non_nullable; + a_src_non_nullable = ArrayFromJSON(int8(), "[11, 23, 56]"); + b_src_non_nullable = ArrayFromJSON(int8(), "[32, 46, 37]"); + c_src_non_nullable = ArrayFromJSON(int8(), "[95, 11, 44]"); + ASSERT_OK_AND_ASSIGN( + auto src_non_nullable, + StructArray::Make({a_src_non_nullable, b_src_non_nullable, c_src_non_nullable}, + fields_src_non_nullable)); + + std::shared_ptr a_dest_nullable, b_dest_nullable, c_dest_nullable; + a_dest_nullable = ArrayFromJSON(int64(), "[11, 23, 56]"); + b_dest_nullable = ArrayFromJSON(int64(), "[32, 46, 37]"); + c_dest_nullable = ArrayFromJSON(int64(), "[95, 11, 44]"); + + std::vector> fields_dest1_nullable = { + std::make_shared("a", int64(), true), + std::make_shared("b", int64(), true), + std::make_shared("c", int64(), true)}; + ASSERT_OK_AND_ASSIGN( + auto dest1_nullable, + StructArray::Make({a_dest_nullable, b_dest_nullable, c_dest_nullable}, + fields_dest1_nullable)); + CheckCast(src_non_nullable, dest1_nullable); + + std::vector> fields_dest2_nullable = { + std::make_shared("a", int64(), true), + std::make_shared("c", int64(), true)}; + ASSERT_OK_AND_ASSIGN( + auto dest2_nullable, + StructArray::Make({a_dest_nullable, c_dest_nullable}, fields_dest2_nullable)); + CheckCast(src_non_nullable, dest2_nullable); + + std::vector> fields_dest3_nullable = { + std::make_shared("b", int64(), true)}; + ASSERT_OK_AND_ASSIGN(auto dest3_nullable, + StructArray::Make({b_dest_nullable}, fields_dest3_nullable)); + CheckCast(src_non_nullable, dest3_nullable); + } + { + // But NOT OK to go from nullable to non-nullable... + std::vector> fields_src_nullable = { + std::make_shared("a", int8(), true), + std::make_shared("b", int8(), true), + std::make_shared("c", int8(), true)}; + std::shared_ptr a_src_nullable, b_src_nullable, c_src_nullable; + a_src_nullable = ArrayFromJSON(int8(), "[1, null, 5]"); + b_src_nullable = ArrayFromJSON(int8(), "[3, 4, null]"); + c_src_nullable = ArrayFromJSON(int8(), "[9, 11, 44]"); + ASSERT_OK_AND_ASSIGN( + auto src_nullable, + StructArray::Make({a_src_nullable, b_src_nullable, c_src_nullable}, + fields_src_nullable)); + + std::vector> fields_dest1_non_nullable = { + std::make_shared("a", int64(), false), + std::make_shared("b", int64(), false), + std::make_shared("c", int64(), false)}; + const auto dest1_non_nullable = arrow::struct_(fields_dest1_non_nullable); + const auto options1_non_nullable = CastOptions::Safe(dest1_non_nullable); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("cannot cast nullable field to non-nullable field"), + Cast(src_nullable, options1_non_nullable)); + + std::vector> fields_dest2_non_nullble = { + std::make_shared("a", int64(), false), + std::make_shared("c", int64(), false)}; + const auto dest2_non_nullable = arrow::struct_(fields_dest2_non_nullble); + const auto options2_non_nullable = CastOptions::Safe(dest2_non_nullable); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("cannot cast nullable field to non-nullable field"), + Cast(src_nullable, options2_non_nullable)); + + std::vector> fields_dest3_non_nullble = { + std::make_shared("c", int64(), false)}; + const auto dest3_non_nullable = arrow::struct_(fields_dest3_non_nullble); + const auto options3_non_nullable = CastOptions::Safe(dest3_non_nullable); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("cannot cast nullable field to non-nullable field"), + Cast(src_nullable, options3_non_nullable)); + } } TEST(Cast, IdentityCasts) { diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 7fedb7d3c72..b211ce89947 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -1656,7 +1656,9 @@ TEST(ScanNode, MaterializationOfNestedVirtualColumn) { // TODO(ARROW-1888): allow scanner to "patch up" structs with casts EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT( - TypeError, ::testing::HasSubstr("struct field sizes do not match"), plan.Run()); + TypeError, + ::testing::HasSubstr("struct fields don't match or are in the wrong order"), + plan.Run()); } TEST(ScanNode, MinimalEndToEnd) {