Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ void AddTypeToTypeCast(CastFunction* func) {
kernel.exec = CastFunctor::Exec;
kernel.signature = KernelSignature::Make({InputType(SrcT::type_id)}, kOutputTargetType);
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
DCHECK_OK(func->AddKernel(StructType::type_id, std::move(kernel)));
DCHECK_OK(func->AddKernel(SrcT::type_id, std::move(kernel)));
}

template <typename DestType>
Expand Down Expand Up @@ -480,14 +480,18 @@ std::vector<std::shared_ptr<CastFunction>> GetNestedCasts() {
auto cast_list = std::make_shared<CastFunction>("cast_list", Type::LIST);
AddCommonCasts(Type::LIST, kOutputTargetType, cast_list.get());
AddListCast<ListType, ListType>(cast_list.get());
AddListCast<ListViewType, ListType>(cast_list.get());
AddListCast<LargeListType, ListType>(cast_list.get());
AddListCast<LargeListViewType, ListType>(cast_list.get());
AddTypeToTypeCast<CastFixedToVarList<ListType>, FixedSizeListType>(cast_list.get());

auto cast_large_list =
std::make_shared<CastFunction>("cast_large_list", Type::LARGE_LIST);
AddCommonCasts(Type::LARGE_LIST, kOutputTargetType, cast_large_list.get());
AddListCast<ListType, LargeListType>(cast_large_list.get());
AddListCast<ListViewType, LargeListType>(cast_large_list.get());
AddListCast<LargeListType, LargeListType>(cast_large_list.get());
AddListCast<LargeListViewType, LargeListType>(cast_large_list.get());
AddTypeToTypeCast<CastFixedToVarList<LargeListType>, FixedSizeListType>(
cast_large_list.get());

Expand All @@ -503,7 +507,11 @@ std::vector<std::shared_ptr<CastFunction>> GetNestedCasts() {
AddCommonCasts(Type::FIXED_SIZE_LIST, kOutputTargetType, cast_fsl.get());
AddTypeToTypeCast<CastFixedList, FixedSizeListType>(cast_fsl.get());
AddTypeToTypeCast<CastVarToFixedList<ListType>, ListType>(cast_fsl.get());
AddTypeToTypeCast<CastVarToFixedList<ListViewType>, ListViewType>(cast_fsl.get());
AddTypeToTypeCast<CastVarToFixedList<LargeListType>, LargeListType>(cast_fsl.get());
AddTypeToTypeCast<CastVarToFixedList<LargeListViewType>, LargeListViewType>(
cast_fsl.get());
AddTypeToTypeCast<CastVarToFixedList<ListType>, MapType>(cast_fsl.get());

// So is struct
auto cast_struct = std::make_shared<CastFunction>("cast_struct", Type::STRUCT);
Expand Down
39 changes: 25 additions & 14 deletions cpp/src/arrow/scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,8 @@ std::shared_ptr<DataType> MakeListType<FixedSizeListType>(

template <typename ScalarType>
void CheckListCast(const ScalarType& scalar, const std::shared_ptr<DataType>& to_type) {
EXPECT_OK_AND_ASSIGN(auto cast_scalar, scalar.CastTo(to_type));
EXPECT_OK_AND_ASSIGN(auto cast_scalar_datum, Cast(scalar, to_type));
const auto& cast_scalar = cast_scalar_datum.scalar();
ASSERT_OK(cast_scalar->ValidateFull());
ASSERT_EQ(*cast_scalar->type, *to_type);

Expand All @@ -1087,11 +1088,25 @@ void CheckListCast(const ScalarType& scalar, const std::shared_ptr<DataType>& to
*checked_cast<const BaseListScalar&>(*cast_scalar).value);
}

void CheckInvalidListCast(const Scalar& scalar, const std::shared_ptr<DataType>& to_type,
const std::string& expected_message) {
EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT(StatusCode::Invalid,
::testing::HasSubstr(expected_message),
scalar.CastTo(to_type));
template <typename ScalarType>
void CheckListCastError(const ScalarType& scalar,
const std::shared_ptr<DataType>& to_type) {
StatusCode code;
std::string expected_message;
if (scalar.type->id() == Type::FIXED_SIZE_LIST) {
code = StatusCode::TypeError;
expected_message =
"Size of FixedSizeList is not the same. input list: " + scalar.type->ToString() +
" output list: " + to_type->ToString();
} else {
code = StatusCode::Invalid;
expected_message =
"ListType can only be casted to FixedSizeListType if the lists are all the "
"expected size.";
}

EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT(code, ::testing::HasSubstr(expected_message),
Cast(scalar, to_type));
}

template <typename T>
Expand Down Expand Up @@ -1178,10 +1193,8 @@ class TestListLikeScalar : public ::testing::Test {
CheckListCast(
scalar, fixed_size_list(value_->type(), static_cast<int32_t>(value_->length())));

CheckInvalidListCast(scalar, fixed_size_list(value_->type(), 5),
"Cannot cast " + scalar.type->ToString() + " of length " +
std::to_string(value_->length()) +
" to fixed size list of length 5");
auto invalid_cast_type = fixed_size_list(value_->type(), 5);
CheckListCastError(scalar, invalid_cast_type);
}

protected:
Expand Down Expand Up @@ -1238,10 +1251,8 @@ TEST(TestMapScalar, Cast) {
CheckListCast(scalar, large_list(key_value_type));
CheckListCast(scalar, fixed_size_list(key_value_type, 2));

CheckInvalidListCast(scalar, fixed_size_list(key_value_type, 5),
"Cannot cast " + scalar.type->ToString() + " of length " +
std::to_string(value->length()) +
" to fixed size list of length 5");
auto invalid_cast_type = fixed_size_list(key_value_type, 5);
CheckListCastError(scalar, invalid_cast_type);
}

TEST(TestStructScalar, FieldAccess) {
Expand Down