Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace internal {

// ----------------------------------------------------------------------

namespace {

template <typename OutT, typename InT>
ARROW_DISABLE_UBSAN("float-cast-overflow")
void DoStaticCast(const void* in_data, int64_t in_offset, int64_t length,
Expand Down Expand Up @@ -117,6 +119,8 @@ void CastNumberImpl(Type::type out_type, const Datum& input, Datum* out) {
}
}

} // namespace

void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Datum& input,
Datum* out) {
switch (in_type) {
Expand Down
141 changes: 98 additions & 43 deletions cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

// Implementation of casting to (or between) list types

#include <limits>
#include <utility>
#include <vector>

Expand All @@ -26,6 +27,7 @@
#include "arrow/compute/kernels/common.h"
#include "arrow/compute/kernels/scalar_cast_internal.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/int_util.h"

namespace arrow {

Expand All @@ -34,82 +36,135 @@ using internal::CopyBitmap;
namespace compute {
namespace internal {

template <typename Type>
Status CastListExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
using offset_type = typename Type::offset_type;
using ScalarType = typename TypeTraits<Type>::ScalarType;
namespace {

const CastOptions& options = CastState::Get(ctx);
// (Large)List<T> -> (Large)List<U>

auto child_type = checked_cast<const Type&>(*out->type()).value_type();
template <typename SrcType, typename DestType>
typename std::enable_if<SrcType::type_id == DestType::type_id, Status>::type
CastListOffsets(KernelContext* ctx, const ArrayData& in_array, ArrayData* out_array) {
return Status::OK();
}

template <typename SrcType, typename DestType>
typename std::enable_if<SrcType::type_id != DestType::type_id, Status>::type
CastListOffsets(KernelContext* ctx, const ArrayData& in_array, ArrayData* out_array) {
using src_offset_type = typename SrcType::offset_type;
using dest_offset_type = typename DestType::offset_type;

ARROW_ASSIGN_OR_RAISE(out_array->buffers[1],
ctx->Allocate(sizeof(dest_offset_type) * (in_array.length + 1)));
::arrow::internal::CastInts(in_array.GetValues<src_offset_type>(1),
out_array->GetMutableValues<dest_offset_type>(1),
in_array.length + 1);
return Status::OK();
}

template <typename SrcType, typename DestType>
struct CastList {
using src_offset_type = typename SrcType::offset_type;
using dest_offset_type = typename DestType::offset_type;

static constexpr bool is_upcast = sizeof(src_offset_type) < sizeof(dest_offset_type);
static constexpr bool is_downcast = sizeof(src_offset_type) > sizeof(dest_offset_type);

static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const CastOptions& options = CastState::Get(ctx);

auto child_type = checked_cast<const DestType&>(*out->type()).value_type();

if (out->kind() == Datum::SCALAR) {
const auto& in_scalar = checked_cast<const ScalarType&>(*batch[0].scalar());
auto out_scalar = checked_cast<ScalarType*>(out->scalar().get());
if (out->kind() == Datum::SCALAR) {
// The scalar case is simple, as only the underlying values must be cast
const auto& in_scalar = checked_cast<const BaseListScalar&>(*batch[0].scalar());
auto out_scalar = checked_cast<BaseListScalar*>(out->scalar().get());

DCHECK(!out_scalar->is_valid);
if (in_scalar.is_valid) {
ARROW_ASSIGN_OR_RAISE(out_scalar->value, Cast(*in_scalar.value, child_type, options,
ctx->exec_context()));
DCHECK(!out_scalar->is_valid);
if (in_scalar.is_valid) {
ARROW_ASSIGN_OR_RAISE(out_scalar->value, Cast(*in_scalar.value, child_type,
options, ctx->exec_context()));

out_scalar->is_valid = true;
out_scalar->is_valid = true;
}
return Status::OK();
}
return Status::OK();
}

const ArrayData& in_array = *batch[0].array();
ArrayData* out_array = out->mutable_array();
const ArrayData& in_array = *batch[0].array();
auto offsets = in_array.GetValues<src_offset_type>(1);
Datum values = in_array.child_data[0];

// Copy from parent
out_array->buffers = in_array.buffers;
Datum values = in_array.child_data[0];
ArrayData* out_array = out->mutable_array();
out_array->buffers = in_array.buffers;

if (in_array.offset != 0) {
if (in_array.buffers[0]) {
// Shift bitmap in case the source offset is non-zero
if (in_array.offset != 0 && in_array.buffers[0]) {
ARROW_ASSIGN_OR_RAISE(out_array->buffers[0],
CopyBitmap(ctx->memory_pool(), in_array.buffers[0]->data(),
in_array.offset, in_array.length));
}
ARROW_ASSIGN_OR_RAISE(out_array->buffers[1],
ctx->Allocate(sizeof(offset_type) * (in_array.length + 1)));

auto offsets = in_array.GetValues<offset_type>(1);
auto shifted_offsets = out_array->GetMutableValues<offset_type>(1);
// Handle list offsets
// Several cases can arise:
// - the source offset is non-zero, in which case we slice the underlying values
// and shift the list offsets (regardless of their respective types)
// - the source offset is zero but source and destination types have
// different list offset types, in which case we cast the list offsets
// - otherwise, we simply keep the original list offsets
if (is_downcast) {
if (offsets[in_array.length] > std::numeric_limits<dest_offset_type>::max()) {
return Status::Invalid("Array of type ", in_array.type->ToString(),
" too large to convert to ", out_array->type->ToString());
}
}

for (int64_t i = 0; i < in_array.length + 1; ++i) {
shifted_offsets[i] = offsets[i] - offsets[0];
if (in_array.offset != 0) {
ARROW_ASSIGN_OR_RAISE(
out_array->buffers[1],
ctx->Allocate(sizeof(dest_offset_type) * (in_array.length + 1)));

auto shifted_offsets = out_array->GetMutableValues<dest_offset_type>(1);
for (int64_t i = 0; i < in_array.length + 1; ++i) {
shifted_offsets[i] = static_cast<dest_offset_type>(offsets[i] - offsets[0]);
}
values = in_array.child_data[0]->Slice(offsets[0], offsets[in_array.length]);
} else {
RETURN_NOT_OK((CastListOffsets<SrcType, DestType>(ctx, in_array, out_array)));
}
values = in_array.child_data[0]->Slice(offsets[0], offsets[in_array.length]);
}

ARROW_ASSIGN_OR_RAISE(Datum cast_values,
Cast(values, child_type, options, ctx->exec_context()));
// Handle values
ARROW_ASSIGN_OR_RAISE(Datum cast_values,
Cast(values, child_type, options, ctx->exec_context()));

DCHECK_EQ(Datum::ARRAY, cast_values.kind());
out_array->child_data.push_back(cast_values.array());
return Status::OK();
}
DCHECK_EQ(Datum::ARRAY, cast_values.kind());
out_array->child_data.push_back(cast_values.array());
return Status::OK();
}
};

template <typename Type>
template <typename SrcType, typename DestType>
void AddListCast(CastFunction* func) {
ScalarKernel kernel;
kernel.exec = CastListExec<Type>;
kernel.signature = KernelSignature::Make({InputType(Type::type_id)}, kOutputTargetType);
kernel.exec = CastList<SrcType, DestType>::Exec;
kernel.signature =
KernelSignature::Make({InputType(SrcType::type_id)}, kOutputTargetType);
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
DCHECK_OK(func->AddKernel(Type::type_id, std::move(kernel)));
DCHECK_OK(func->AddKernel(SrcType::type_id, std::move(kernel)));
}

} // namespace

std::vector<std::shared_ptr<CastFunction>> GetNestedCasts() {
// We use the list<T> from the CastOptions when resolving the output type

auto cast_list = std::make_shared<CastFunction>("cast_list", Type::LIST);
AddCommonCasts(Type::LIST, kOutputTargetType, cast_list.get());
AddListCast<ListType>(cast_list.get());
AddListCast<ListType, ListType>(cast_list.get());
AddListCast<LargeListType, ListType>(cast_list.get());

auto cast_large_list =
std::make_shared<CastFunction>("cast_large_list", Type::LARGE_LIST);
AddCommonCasts(Type::LARGE_LIST, kOutputTargetType, cast_large_list.get());
AddListCast<LargeListType>(cast_large_list.get());
AddListCast<ListType, LargeListType>(cast_large_list.get());
AddListCast<LargeListType, LargeListType>(cast_large_list.get());

// FSL is a bit incomplete at the moment
auto cast_fsl =
Expand Down
79 changes: 37 additions & 42 deletions cpp/src/arrow/compute/kernels/scalar_cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1676,56 +1676,51 @@ TEST(Cast, ListToPrimitive) {
Cast(*ArrayFromJSON(list(binary()), R"([["1", "2"], ["3", "4"]])"), utf8()));
}

TEST(Cast, ListToList) {
using make_list_t = std::shared_ptr<DataType>(const std::shared_ptr<DataType>&);
for (auto make_list : std::vector<make_list_t*>{&list, &large_list}) {
auto list_int32 =
ArrayFromJSON(make_list(int32()),
"[[0], [1], null, [2, 3, 4], [5, 6], null, [], [7], [8, 9]]")
->data();

auto list_int64 = list_int32->Copy();
list_int64->type = make_list(int64());
list_int64->child_data[0] = Cast(list_int32->child_data[0], int64())->array();
ValidateOutput(*list_int64);

auto list_float32 = list_int32->Copy();
list_float32->type = make_list(float32());
list_float32->child_data[0] = Cast(list_int32->child_data[0], float32())->array();
ValidateOutput(*list_float32);

CheckCast(MakeArray(list_int32), MakeArray(list_float32));
CheckCast(MakeArray(list_float32), MakeArray(list_int64));
CheckCast(MakeArray(list_int64), MakeArray(list_float32));

CheckCast(MakeArray(list_int32), MakeArray(list_int64));
CheckCast(MakeArray(list_float32), MakeArray(list_int32));
CheckCast(MakeArray(list_int64), MakeArray(list_int32));
using make_list_t = std::shared_ptr<DataType>(const std::shared_ptr<DataType>&);

static const auto list_factories = std::vector<make_list_t*>{&list, &large_list};

static void CheckListToList(const std::vector<std::shared_ptr<DataType>>& value_types,
const std::string& json_data) {
for (auto make_src_list : list_factories) {
for (auto make_dest_list : list_factories) {
for (const auto& src_value_type : value_types) {
for (const auto& dest_value_type : value_types) {
const auto src_type = make_src_list(src_value_type);
const auto dest_type = make_dest_list(dest_value_type);
ARROW_SCOPED_TRACE("src_type = ", src_type->ToString(),
", dest_type = ", dest_type->ToString());
CheckCast(ArrayFromJSON(src_type, json_data),
ArrayFromJSON(dest_type, json_data));
}
}
}
}
}

// No nulls (ARROW-12568)
for (auto make_list : std::vector<make_list_t*>{&list, &large_list}) {
auto list_int32 = ArrayFromJSON(make_list(int32()),
"[[0], [1], [2, 3, 4], [5, 6], [], [7], [8, 9]]")
->data();
auto list_int64 = list_int32->Copy();
list_int64->type = make_list(int64());
list_int64->child_data[0] = Cast(list_int32->child_data[0], int64())->array();
ValidateOutput(*list_int64);
// List-to-list casts over int32 / float32 / int64 value types, with input
// data that contains null lists and an empty list.
TEST(Cast, ListToList) {
  const std::vector<std::shared_ptr<DataType>> value_types = {int32(), float32(),
                                                              int64()};
  const std::string json_data =
      "[[0], [1], null, [2, 3, 4], [5, 6], null, [], [7], [8, 9]]";
  CheckListToList(value_types, json_data);
}

CheckCast(MakeArray(list_int32), MakeArray(list_int64));
CheckCast(MakeArray(list_int64), MakeArray(list_int32));
}
// Same list-to-list coverage as above but with input that has no null lists
// (regression coverage for ARROW-12568).
TEST(Cast, ListToListNoNulls) {
  const std::vector<std::shared_ptr<DataType>> value_types = {int32(), float32(),
                                                              int64()};
  const std::string json_data = "[[0], [1], [2, 3, 4], [5, 6], [], [7], [8, 9]]";
  CheckListToList(value_types, json_data);
}

TEST(Cast, ListToListOptionsPassthru) {
auto list_int32 = ArrayFromJSON(list(int32()), "[[87654321]]");
for (auto make_src_list : list_factories) {
for (auto make_dest_list : list_factories) {
auto list_int32 = ArrayFromJSON(make_src_list(int32()), "[[87654321]]");

auto options = CastOptions::Safe(list(int16()));
CheckCastFails(list_int32, options);
auto options = CastOptions::Safe(make_dest_list(int16()));
CheckCastFails(list_int32, options);

options.allow_int_overflow = true;
CheckCast(list_int32, ArrayFromJSON(list(int16()), "[[32689]]"), options);
options.allow_int_overflow = true;
CheckCast(list_int32, ArrayFromJSON(make_dest_list(int16()), "[[32689]]"), options);
}
}
}

TEST(Cast, IdentityCasts) {
Expand Down
5 changes: 5 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,11 @@ def test_cast():
expected = pa.array([1262304000000, 1420070400000], type='timestamp[ms]')
assert pc.cast(arr, 'timestamp[ms]') == expected

arr = pa.array([[1, 2], [3, 4, 5]], type=pa.large_list(pa.int8()))
expected = pa.array([["1", "2"], ["3", "4", "5"]],
type=pa.list_(pa.utf8()))
assert pc.cast(arr, expected.type) == expected


def test_strptime():
arr = pa.array(["5/1/2020", None, "12/13/1900"])
Expand Down