diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc
index 84af8f7c6b0..3fc15018630 100644
--- a/cpp/src/arrow/compute/compute-test.cc
+++ b/cpp/src/arrow/compute/compute-test.cc
@@ -801,6 +801,54 @@ TYPED_TEST(TestDictionaryCast, Basic) {
   this->CheckPass(*plain_array, *dict_array, dict_array->type(), options);
 }*/
 
+// The expected outputs are built by casting the flat child array directly and
+// re-wrapping it with the same offsets, so a list-to-list cast has to keep the
+// offsets intact and convert only the child values.
+TEST_F(TestCast, ListToList) {
+  CastOptions options;
+  std::shared_ptr<Array> offsets;
+
+  // Offsets end at 10, matching the 10 passed to MakeRandomArray below
+  // (presumably the child length). Entry 4 is marked invalid, which
+  // presumably makes that list null -- confirm in ListArray::FromArrays.
+  vector<int32_t> offsets_values = {0, 1, 2, 5, 7, 7, 8, 10};
+  std::vector<bool> offsets_is_valid = {true, true, true, true, false, true, true, true};
+  ArrayFromVector<Int32Type, int32_t>(offsets_is_valid, offsets_values, &offsets);
+
+  shared_ptr<Array> int32_plain_array =
+      TestBase::MakeRandomArray<TypeTraits<Int32Type>::ArrayType>(10, 2);
+  std::shared_ptr<Array> int32_list_array;
+  ASSERT_OK(
+      ListArray::FromArrays(*offsets, *int32_plain_array, pool_, &int32_list_array));
+
+  std::shared_ptr<Array> int64_plain_array;
+  ASSERT_OK(Cast(&this->ctx_, *int32_plain_array, int64(), options, &int64_plain_array));
+  std::shared_ptr<Array> int64_list_array;
+  ASSERT_OK(
+      ListArray::FromArrays(*offsets, *int64_plain_array, pool_, &int64_list_array));
+
+  std::shared_ptr<Array> float64_plain_array;
+  ASSERT_OK(
+      Cast(&this->ctx_, *int32_plain_array, float64(), options, &float64_plain_array));
+  std::shared_ptr<Array> float64_list_array;
+  ASSERT_OK(
+      ListArray::FromArrays(*offsets, *float64_plain_array, pool_, &float64_list_array));
+
+  // Exercise every ordered pair of distinct element types.
+  this->CheckPass(*int32_list_array, *int64_list_array, int64_list_array->type(),
+                  options);
+  this->CheckPass(*int32_list_array, *float64_list_array, float64_list_array->type(),
+                  options);
+  this->CheckPass(*int64_list_array, *int32_list_array, int32_list_array->type(),
+                  options);
+  this->CheckPass(*int64_list_array, *float64_list_array, float64_list_array->type(),
+                  options);
+  this->CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(),
+                  options);
+  this->CheckPass(*float64_list_array, *int64_list_array, int64_list_array->type(),
+                  options);
+}
+
 // ----------------------------------------------------------------------
 // Dictionary tests
 
diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc
index 465be958cfa..afa05485f65 100644
--- a/cpp/src/arrow/compute/kernels/cast.cc
+++ b/cpp/src/arrow/compute/kernels/cast.cc
@@ -460,6 +460,55 @@ struct CastFunctor {
   }
 };
 
+// ----------------------------------------------------------------------
+// List to List
+
+// Casts a list array to another list type: the parent's buffers are reused
+// unchanged and only the child values are converted by the wrapped kernel.
+class ListCastKernel : public UnaryKernel {
+ public:
+  // child_caster: kernel applied to the input's first child (child_data[0]).
+  // out_type: type used when Call() has to allocate the output array.
+  ListCastKernel(std::unique_ptr<UnaryKernel> child_caster,
+                 const std::shared_ptr<DataType>& out_type)
+      : child_caster_(std::move(child_caster)), out_type_(out_type) {}
+
+  // Casts the list array in `input` into `out`, allocating `out` first when it
+  // is empty (Datum::NONE). Inputs with a non-zero offset are rejected with
+  // NotImplemented; a failure from the child cast is returned to the caller.
+  Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
+    DCHECK_EQ(Datum::ARRAY, input.kind());
+
+    const ArrayData& in_data = *input.array();
+    DCHECK_EQ(Type::LIST, in_data.type->id());
+
+    if (in_data.offset != 0) {
+      return Status::NotImplemented(
+          "Casting sliced lists (non-zero offset) not yet implemented");
+    }
+
+    if (out->kind() == Datum::NONE) {
+      out->value = ArrayData::Make(out_type_, in_data.length);
+    }
+
+    ArrayData* result = out->array().get();
+
+    // Copy buffers from parent (reused as-is; only the child values are cast)
+    result->buffers = in_data.buffers;
+
+    Datum casted_child;
+    RETURN_NOT_OK(child_caster_->Call(ctx, Datum(in_data.child_data[0]), &casted_child));
+    result->child_data.push_back(casted_child.array());
+
+    RETURN_IF_ERROR(ctx);
+    return Status::OK();
+  }
+
+ private:
+  std::unique_ptr<UnaryKernel> child_caster_;
+  std::shared_ptr<DataType> out_type_;
+};
+
 // ----------------------------------------------------------------------
 // Dictionary to other things
 
@@ -895,7 +944,6 @@ GET_CAST_FUNCTION(DATE64_CASES, Date64Type);
 GET_CAST_FUNCTION(TIME32_CASES, Time32Type);
 GET_CAST_FUNCTION(TIME64_CASES, Time64Type);
 GET_CAST_FUNCTION(TIMESTAMP_CASES, TimestampType);
-
 GET_CAST_FUNCTION(DICTIONARY_CASES, DictionaryType);
 
 #define CAST_FUNCTION_CASE(InType)                      \
@@ -903,6 +951,30 @@ GET_CAST_FUNCTION(DICTIONARY_CASES, DictionaryType);
     *kernel = Get##InType##CastFunc(out_type, options); \
     break
 
+namespace {
+
+// Builds the kernel that casts list type `in_type` to `out_type`. When
+// `out_type` is not a list, *kernel is left untouched and OK is returned.
+// Otherwise the element-type caster from GetCastFunction is wrapped in a
+// ListCastKernel; an error from that lookup is returned to the caller.
+Status GetListCastFunc(const DataType& in_type, const std::shared_ptr<DataType>& out_type,
+                       const CastOptions& options, std::unique_ptr<UnaryKernel>* kernel) {
+  if (out_type->id() != Type::LIST) {
+    // Kernel will be null
+    return Status::OK();
+  }
+  const DataType& in_value_type = *static_cast<const ListType&>(in_type).value_type();
+  std::shared_ptr<DataType> out_value_type =
+      static_cast<const ListType&>(*out_type).value_type();
+  std::unique_ptr<UnaryKernel> child_caster;
+  RETURN_NOT_OK(GetCastFunction(in_value_type, out_value_type, options, &child_caster));
+  *kernel =
+      std::unique_ptr<UnaryKernel>(new ListCastKernel(std::move(child_caster), out_type));
+  return Status::OK();
+}
+
+}  // namespace
+
 Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>& out_type,
                        const CastOptions& options, std::unique_ptr<UnaryKernel>* kernel) {
   switch (in_type.id()) {
@@ -924,6 +996,9 @@ Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>&
     CAST_FUNCTION_CASE(Time64Type);
     CAST_FUNCTION_CASE(TimestampType);
     CAST_FUNCTION_CASE(DictionaryType);
+    case Type::LIST:
+      RETURN_NOT_OK(GetListCastFunc(in_type, out_type, options, kernel));
+      break;
     default:
       break;
   }