diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc b/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc index ab583bbbe8c..cc36c510363 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc @@ -150,6 +150,84 @@ void AddListCast(CastFunction* func) { DCHECK_OK(func->AddKernel(SrcType::type_id, std::move(kernel))); } +struct CastStruct { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const CastOptions& options = CastState::Get(ctx); + const StructType& in_type = checked_cast(*batch[0].type()); + const StructType& out_type = checked_cast(*out->type()); + const auto in_field_count = in_type.num_fields(); + + if (in_field_count != out_type.num_fields()) { + return Status::TypeError("struct field sizes do not match: ", in_type.ToString(), + " ", out_type.ToString()); + } + + for (int i = 0; i < in_field_count; ++i) { + const auto in_field = in_type.field(i); + const auto out_field = out_type.field(i); + if (in_field->name() != out_field->name()) { + return Status::TypeError("struct field names do not match: ", in_type.ToString(), + " ", out_type.ToString()); + } + + if (in_field->nullable() && !out_field->nullable()) { + return Status::TypeError("cannot cast nullable struct to non-nullable struct: ", + in_type.ToString(), " ", out_type.ToString()); + } + } + + if (out->kind() == Datum::SCALAR) { + const auto& in_scalar = checked_cast(*batch[0].scalar()); + auto out_scalar = checked_cast(out->scalar().get()); + + DCHECK(!out_scalar->is_valid); + if (in_scalar.is_valid) { + for (int i = 0; i < in_field_count; i++) { + auto values = in_scalar.value[i]; + auto target_type = out->type()->field(i)->type(); + ARROW_ASSIGN_OR_RAISE(Datum cast_values, + Cast(values, target_type, options, ctx->exec_context())); + DCHECK_EQ(Datum::SCALAR, cast_values.kind()); + out_scalar->value.push_back(cast_values.scalar()); + } + out_scalar->is_valid = true; + } + return Status::OK(); + } + + const ArrayData& in_array = *batch[0].array(); + ArrayData* out_array = out->mutable_array(); + + if (in_array.buffers[0]) { + ARROW_ASSIGN_OR_RAISE(out_array->buffers[0], + CopyBitmap(ctx->memory_pool(), in_array.buffers[0]->data(), + in_array.offset, in_array.length)); + } + + for (int i = 0; i < in_field_count; ++i) { + auto values = in_array.child_data[i]->Slice(in_array.offset, in_array.length); + auto target_type = out->type()->field(i)->type(); + + ARROW_ASSIGN_OR_RAISE(Datum cast_values, + Cast(values, target_type, options, ctx->exec_context())); + + DCHECK_EQ(Datum::ARRAY, cast_values.kind()); + out_array->child_data.push_back(cast_values.array()); + } + + return Status::OK(); + } +}; + +void AddStructToStructCast(CastFunction* func) { + ScalarKernel kernel; + kernel.exec = CastStruct::Exec; + kernel.signature = + KernelSignature::Make({InputType(StructType::type_id)}, kOutputTargetType); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel))); +} + } // namespace std::vector> GetNestedCasts() { @@ -174,6 +252,7 @@ std::vector> GetNestedCasts() { // So is struct auto cast_struct = std::make_shared("cast_struct", Type::STRUCT); AddCommonCasts(Type::STRUCT, kOutputTargetType, cast_struct.get()); + AddStructToStructCast(cast_struct.get()); // So is dictionary auto cast_dictionary = diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index cb2eb62f6f2..8a4f2c69456 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -2218,6 +2218,117 @@ TEST(Cast, ListToListOptionsPassthru) { } } +static void CheckStructToStruct( + const std::vector>& value_types) { + for (const auto& src_value_type : value_types) { + for (const auto& dest_value_type : value_types) { + std::vector field_names = {"a", "b"}; + std::shared_ptr a1, b1, a2, b2; + a1 = ArrayFromJSON(src_value_type, "[1, 2, 3, 4, null]"); + b1 = ArrayFromJSON(src_value_type, "[null, 7, 8, 9, 0]"); + a2 = ArrayFromJSON(dest_value_type, "[1, 2, 3, 4, null]"); + b2 = ArrayFromJSON(dest_value_type, "[null, 7, 8, 9, 0]"); + ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a1, b1}, field_names)); + ASSERT_OK_AND_ASSIGN(auto dest, StructArray::Make({a2, b2}, field_names)); + + CheckCast(src, dest); + + std::shared_ptr null_bitmap; + BitmapFromVector({0, 1, 0, 1, 0}, &null_bitmap); + + ASSERT_OK_AND_ASSIGN(auto src_nulls, + StructArray::Make({a1, b1}, field_names, null_bitmap)); + ASSERT_OK_AND_ASSIGN(auto dest_nulls, + StructArray::Make({a2, b2}, field_names, null_bitmap)); + CheckCast(src_nulls, dest_nulls); + } + } +} + +TEST(Cast, StructToSameSizedAndNamedStruct) { + CheckStructToStruct({int32(), float32(), int64()}); +} + +TEST(Cast, StructToSameSizedButDifferentNamedStruct) { + std::vector field_names = {"a", "b"}; + std::shared_ptr a, b; + a = ArrayFromJSON(int8(), "[1, 2]"); + b = ArrayFromJSON(int8(), "[3, 4]"); + ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a, b}, field_names)); + + const auto dest = arrow::struct_( + {std::make_shared("c", int8()), std::make_shared("d", int8())}); + const auto options = CastOptions::Safe(dest); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("Type error: struct field names do not match: struct struct"), + Cast(src, options)); +} + +TEST(Cast, StructToDifferentSizeStruct) { + std::vector field_names = {"a", "b"}; + std::shared_ptr a, b; + a = ArrayFromJSON(int8(), "[1, 2]"); + b = ArrayFromJSON(int8(), "[3, 4]"); + ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a, b}, field_names)); + + const auto dest = arrow::struct_({std::make_shared("a", int8()), + std::make_shared("b", int8()), + std::make_shared("c", int8())}); + const auto options = CastOptions::Safe(dest); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("Type error: struct field sizes do not match: struct struct"), + Cast(src, options)); +} + +TEST(Cast, StructToSameSizedButDifferentNullabilityStruct) { + // OK to go from not-nullable to nullable... + std::vector> fields1 = { + std::make_shared("a", int8(), false), + std::make_shared("b", int8(), false)}; + std::shared_ptr a1, b1; + a1 = ArrayFromJSON(int8(), "[1, 2]"); + b1 = ArrayFromJSON(int8(), "[3, 4]"); + ASSERT_OK_AND_ASSIGN(auto src1, StructArray::Make({a1, b1}, fields1)); + + std::vector> fields2 = { + std::make_shared("a", int8(), true), + std::make_shared("b", int8(), true)}; + std::shared_ptr a2, b2; + a2 = ArrayFromJSON(int8(), "[1, 2]"); + b2 = ArrayFromJSON(int8(), "[3, 4]"); + ASSERT_OK_AND_ASSIGN(auto dest1, StructArray::Make({a2, b2}, fields2)); + + CheckCast(src1, dest1); + + // But not the other way around + std::vector> fields3 = { + std::make_shared("a", int8(), true), + std::make_shared("b", int8(), true)}; + std::shared_ptr a3, b3; + a3 = ArrayFromJSON(int8(), "[1, null]"); + b3 = ArrayFromJSON(int8(), "[3, 4]"); + ASSERT_OK_AND_ASSIGN(auto src2, StructArray::Make({a3, b3}, fields3)); + + std::vector> fields4 = { + std::make_shared("a", int8(), false), + std::make_shared("b", int8(), false)}; + const auto dest2 = arrow::struct_(fields4); + const auto options = CastOptions::Safe(dest2); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr( + "Type error: cannot cast nullable struct to non-nullable " + "struct: struct struct"), + Cast(src2, options)); +} + TEST(Cast, IdentityCasts) { // ARROW-4102 auto CheckIdentityCast = [](std::shared_ptr type, const std::string& json) { diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index e62c626d9f8..c37b0666bbb 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -1626,9 +1626,7 @@ TEST(ScanNode, MaterializationOfNestedVirtualColumn) { // TODO(ARROW-1888): allow scanner to "patch up" structs with casts EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT( - NotImplemented, - ::testing::HasSubstr("Unsupported cast from struct to struct"), - plan.Run()); + TypeError, ::testing::HasSubstr("struct field sizes do not match"), plan.Run()); } TEST(ScanNode, MinimalEndToEnd) {