Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions cpp/src/arrow/compute/compute-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,47 @@ TYPED_TEST(TestDictionaryCast, Basic) {
this->CheckPass(*plain_array, *dict_array, dict_array->type(), options);
}*/

TEST_F(TestCast, ListToList) {
CastOptions options;
std::shared_ptr<Array> offsets;

vector<int32_t> offsets_values = {0, 1, 2, 5, 7, 7, 8, 10};
std::vector<bool> offsets_is_valid = {true, true, true, true, false, true, true, true};
ArrayFromVector<Int32Type, int32_t>(offsets_is_valid, offsets_values, &offsets);

shared_ptr<Array> int32_plain_array =
TestBase::MakeRandomArray<typename TypeTraits<Int32Type>::ArrayType>(10, 2);
std::shared_ptr<Array> int32_list_array;
ASSERT_OK(
ListArray::FromArrays(*offsets, *int32_plain_array, pool_, &int32_list_array));

std::shared_ptr<Array> int64_plain_array;
ASSERT_OK(Cast(&this->ctx_, *int32_plain_array, int64(), options, &int64_plain_array));
std::shared_ptr<Array> int64_list_array;
ASSERT_OK(
ListArray::FromArrays(*offsets, *int64_plain_array, pool_, &int64_list_array));

std::shared_ptr<Array> float64_plain_array;
ASSERT_OK(
Cast(&this->ctx_, *int32_plain_array, float64(), options, &float64_plain_array));
std::shared_ptr<Array> float64_list_array;
ASSERT_OK(
ListArray::FromArrays(*offsets, *float64_plain_array, pool_, &float64_list_array));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In here, per below, we will want to verify that the casted result retains zero-copy references to the buffers from the parent array.

There is another nuance to be aware of, which is that the offset member may be non-zero, so in that case simply reusing the buffers is not the appropriate action


this->CheckPass(*int32_list_array, *int64_list_array, int64_list_array->type(),
options);
this->CheckPass(*int32_list_array, *float64_list_array, float64_list_array->type(),
options);
this->CheckPass(*int64_list_array, *int32_list_array, int32_list_array->type(),
options);
this->CheckPass(*int64_list_array, *float64_list_array, float64_list_array->type(),
options);
this->CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(),
options);
this->CheckPass(*float64_list_array, *int64_list_array, int64_list_array->type(),
options);
}

// ----------------------------------------------------------------------
// Dictionary tests

Expand Down
67 changes: 66 additions & 1 deletion cpp/src/arrow/compute/kernels/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,49 @@ struct CastFunctor<Date32Type, Date64Type> {
}
};

// ----------------------------------------------------------------------
// List to List

class ListCastKernel : public UnaryKernel {
public:
ListCastKernel(std::unique_ptr<UnaryKernel> child_caster,
const std::shared_ptr<DataType>& out_type)
: child_caster_(std::move(child_caster)), out_type_(out_type) {}

Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
DCHECK_EQ(Datum::ARRAY, input.kind());

const ArrayData& in_data = *input.array();
DCHECK_EQ(Type::LIST, in_data.type->id());
ArrayData* result;

if (in_data.offset != 0) {
return Status::NotImplemented(
"Casting sliced lists (non-zero offset) not yet implemented");
}

if (out->kind() == Datum::NONE) {
out->value = ArrayData::Make(out_type_, in_data.length);
}

result = out->array().get();

// Copy buffers from parent
result->buffers = in_data.buffers;

Datum casted_child;
RETURN_NOT_OK(child_caster_->Call(ctx, Datum(in_data.child_data[0]), &casted_child));
result->child_data.push_back(casted_child.array());

RETURN_IF_ERROR(ctx);
return Status::OK();
}

private:
std::unique_ptr<UnaryKernel> child_caster_;
std::shared_ptr<DataType> out_type_;
};

// ----------------------------------------------------------------------
// Dictionary to other things

Expand Down Expand Up @@ -895,14 +938,33 @@ GET_CAST_FUNCTION(DATE64_CASES, Date64Type);
GET_CAST_FUNCTION(TIME32_CASES, Time32Type);
GET_CAST_FUNCTION(TIME64_CASES, Time64Type);
GET_CAST_FUNCTION(TIMESTAMP_CASES, TimestampType);

GET_CAST_FUNCTION(DICTIONARY_CASES, DictionaryType);

#define CAST_FUNCTION_CASE(InType) \
case InType::type_id: \
*kernel = Get##InType##CastFunc(out_type, options); \
break

namespace {

Status GetListCastFunc(const DataType& in_type, const std::shared_ptr<DataType>& out_type,
const CastOptions& options, std::unique_ptr<UnaryKernel>* kernel) {
if (out_type->id() != Type::LIST) {
// Kernel will be null
return Status::OK();
}
const DataType& in_value_type = *static_cast<const ListType&>(in_type).value_type();
std::shared_ptr<DataType> out_value_type =
static_cast<const ListType&>(*out_type).value_type();
std::unique_ptr<UnaryKernel> child_caster;
RETURN_NOT_OK(GetCastFunction(in_value_type, out_value_type, options, &child_caster));
*kernel =
std::unique_ptr<UnaryKernel>(new ListCastKernel(std::move(child_caster), out_type));
return Status::OK();
}

} // namespace

Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>& out_type,
const CastOptions& options, std::unique_ptr<UnaryKernel>* kernel) {
switch (in_type.id()) {
Expand All @@ -924,6 +986,9 @@ Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>&
CAST_FUNCTION_CASE(Time64Type);
CAST_FUNCTION_CASE(TimestampType);
CAST_FUNCTION_CASE(DictionaryType);
case Type::LIST:
RETURN_NOT_OK(GetListCastFunc(in_type, out_type, options, kernel));
break;
default:
break;
}
Expand Down