diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index c0b68539e80..48670bf0452 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -1127,6 +1128,62 @@ class ZeroCopyCast : public CastKernelBase { } }; +class ExtensionCastKernel : public CastKernelBase { + public: + static Status Make(const DataType& in_type, std::shared_ptr out_type, + const CastOptions& options, + std::unique_ptr* kernel) { + const auto storage_type = checked_cast(in_type).storage_type(); + + std::unique_ptr storage_caster; + RETURN_NOT_OK(GetCastFunction(*storage_type, out_type, options, &storage_caster)); + kernel->reset( + new ExtensionCastKernel(std::move(storage_caster), std::move(out_type))); + + return Status::OK(); + } + + Status Init(const DataType& in_type) override { + auto& type = checked_cast(in_type); + storage_type_ = type.storage_type(); + extension_name_ = type.extension_name(); + return Status::OK(); + } + + Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override { + DCHECK_EQ(input.kind(), Datum::ARRAY); + + // validate: type is the same as the type the kernel was constructed with + const auto& input_type = checked_cast(*input.type()); + if (input_type.extension_name() != extension_name_) { + return Status::TypeError( + "The cast kernel was constructed to cast from the extension type named '", + extension_name_, "' but input has extension type named '", + input_type.extension_name(), "'"); + } + if (!input_type.storage_type()->Equals(storage_type_)) { + return Status::TypeError("The cast kernel was constructed with a storage type: ", + storage_type_->ToString(), + ", but it is called with a different storage type:", + input_type.storage_type()->ToString()); + } + + // construct an ArrayData object with the underlying storage type + auto new_input = input.array()->Copy(); + new_input->type = storage_type_; + return InvokeWithAllocation(ctx, storage_caster_.get(), new_input, out); + } + + protected: + ExtensionCastKernel(std::unique_ptr storage_caster, + std::shared_ptr out_type) + : CastKernelBase(std::move(out_type)), storage_caster_(std::move(storage_caster)) {} + + std::string extension_name_; + std::shared_ptr storage_type_; + std::unique_ptr storage_caster_; +}; + class CastKernel : public CastKernelBase { public: CastKernel(const CastOptions& options, const CastFunction& func, @@ -1275,11 +1332,6 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr out_ty return Status::OK(); } - if (in_type.id() == Type::NA) { - kernel->reset(new FromNullCastKernel(std::move(out_type))); - return Status::OK(); - } - std::unique_ptr cast_kernel; switch (in_type.id()) { CAST_FUNCTION_CASE(BooleanType); @@ -1304,6 +1356,9 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr out_ty CAST_FUNCTION_CASE(LargeBinaryType); CAST_FUNCTION_CASE(LargeStringType); CAST_FUNCTION_CASE(DictionaryType); + case Type::NA: + cast_kernel.reset(new FromNullCastKernel(std::move(out_type))); + break; case Type::LIST: RETURN_NOT_OK( GetListCastFunc(in_type, std::move(out_type), options, &cast_kernel)); @@ -1312,6 +1367,10 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr out_ty RETURN_NOT_OK(GetListCastFunc(in_type, std::move(out_type), options, &cast_kernel)); break; + case Type::EXTENSION: + RETURN_NOT_OK(ExtensionCastKernel::Make(std::move(in_type), std::move(out_type), + options, &cast_kernel)); + break; default: break; } diff --git a/cpp/src/arrow/compute/kernels/cast_test.cc b/cpp/src/arrow/compute/kernels/cast_test.cc index 7198a10c4d1..65ae570f50c 100644 --- a/cpp/src/arrow/compute/kernels/cast_test.cc +++ b/cpp/src/arrow/compute/kernels/cast_test.cc @@ -26,9 +26,11 @@ #include "arrow/array.h" #include "arrow/buffer.h" +#include "arrow/extension_type.h" #include "arrow/memory_pool.h" #include "arrow/status.h" #include "arrow/table.h" +#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" @@ -1480,5 +1482,54 @@ TYPED_TEST(TestDictionaryCast, OutTypeError) { this->CheckPass(*plain_array, *dict_array, dict_array->type(), options); }*/ +std::shared_ptr SmallintArrayFromJSON(const std::string& json_data) { + auto arr = ArrayFromJSON(int16(), json_data); + auto ext_data = arr->data()->Copy(); + ext_data->type = smallint(); + return MakeArray(ext_data); +} + +TEST_F(TestCast, ExtensionTypeToIntDowncast) { + auto smallint = std::make_shared(); + ASSERT_OK(RegisterExtensionType(smallint)); + + CastOptions options; + options.allow_int_overflow = false; + + std::shared_ptr result; + std::vector is_valid = {true, false, true, true, true}; + + // Smallint(int16) to int16 + auto v0 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]"); + CheckZeroCopy(*v0, int16()); + + // Smallint(int16) to uint8, no overflow/underrun + auto v1 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]"); + auto e1 = ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]"); + CheckPass(*v1, *e1, uint8(), options); + + // Smallint(int16) to uint8, with overflow + auto v2 = SmallintArrayFromJSON("[0, null, 256, 1, 3]"); + auto e2 = ArrayFromJSON(uint8(), "[0, null, 0, 1, 3]"); + // allow overflow + options.allow_int_overflow = true; + CheckPass(*v2, *e2, uint8(), options); + // disallow overflow + options.allow_int_overflow = false; + ASSERT_RAISES(Invalid, Cast(&ctx_, *v2, uint8(), options, &result)); + + // Smallint(int16) to uint8, with underflow + auto v3 = SmallintArrayFromJSON("[0, null, -1, 1, 0]"); + auto e3 = ArrayFromJSON(uint8(), "[0, null, 255, 1, 0]"); + // allow overflow + options.allow_int_overflow = true; + CheckPass(*v3, *e3, uint8(), options); + // disallow overflow + options.allow_int_overflow = false; + ASSERT_RAISES(Invalid, Cast(&ctx_, *v3, uint8(), options, &result)); + + ASSERT_OK(UnregisterExtensionType("smallint")); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/testing/extension_type.h b/cpp/src/arrow/testing/extension_type.h index d29ddccdc5d..0384a28f343 100644 --- a/cpp/src/arrow/testing/extension_type.h +++ b/cpp/src/arrow/testing/extension_type.h @@ -48,10 +48,38 @@ class ARROW_EXPORT UUIDType : public ExtensionType { std::string Serialize() const override { return "uuid-type-unique-code"; } }; +class ARROW_EXPORT SmallintArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +class ARROW_EXPORT SmallintType : public ExtensionType { + public: + SmallintType() : ExtensionType(int16()) {} + + std::string extension_name() const override { return "smallint"; } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + Status Deserialize(std::shared_ptr storage_type, + const std::string& serialized, + std::shared_ptr* out) const override; + + std::string Serialize() const override { return "smallint"; } +}; + ARROW_EXPORT std::shared_ptr uuid(); +ARROW_EXPORT +std::shared_ptr smallint(); + ARROW_EXPORT std::shared_ptr ExampleUUID(); +ARROW_EXPORT +std::shared_ptr ExampleSmallint(); + } // namespace arrow diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 8caf3f1cec9..009ee80b8a7 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -382,11 +382,7 @@ void SleepFor(double seconds) { // Extension types bool UUIDType::ExtensionEquals(const ExtensionType& other) const { - const auto& other_ext = static_cast(other); - if (other_ext.extension_name() != this->extension_name()) { - return false; - } - return true; + return (other.extension_name() == this->extension_name()); } std::shared_ptr UUIDType::MakeArray(std::shared_ptr data) const { @@ -423,4 +419,38 @@ std::shared_ptr ExampleUUID() { return MakeArray(ext_data); } +bool SmallintType::ExtensionEquals(const ExtensionType& other) const { + return (other.extension_name() == this->extension_name()); +} + +std::shared_ptr SmallintType::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("smallint", static_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Status SmallintType::Deserialize(std::shared_ptr storage_type, + const std::string& serialized, + std::shared_ptr* out) const { + if (serialized != "smallint") { + return Status::Invalid("Type identifier did not match"); + } + if (!storage_type->Equals(*int16())) { + return Status::Invalid("Invalid storage type for SmallintType"); + } + *out = std::make_shared(); + return Status::OK(); +} + +std::shared_ptr smallint() { return std::make_shared(); } + +std::shared_ptr ExampleSmallint() { + auto storage_type = int16(); + auto ext_type = smallint(); + auto arr = ArrayFromJSON(storage_type, "[-32768, null, 1, 2, 3, 4, 32767]"); + auto ext_data = arr->data()->Copy(); + ext_data->type = ext_type; + return MakeArray(ext_data); +} + } // namespace arrow diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 64d28e19b47..0d07ac143ea 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -24,6 +24,15 @@ import pytest +class IntegerType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, pa.int64()) + + def __reduce__(self): + return IntegerType, () + + class UuidType(pa.PyExtensionType): def __init__(self): @@ -168,6 +177,42 @@ def test_ext_array_pickling(): assert arr.storage.to_pylist() == [b"foo", b"bar"] +def test_cast_kernel_on_extension_arrays(): + # test array casting + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(IntegerType(), storage) + + # test that no allocation happens during identity cast + allocated_before_cast = pa.total_allocated_bytes() + casted = arr.cast(pa.int64()) + assert pa.total_allocated_bytes() == allocated_before_cast + + cases = [ + (pa.int64(), pa.Int64Array), + (pa.int32(), pa.Int32Array), + (pa.int16(), pa.Int16Array), + (pa.uint64(), pa.UInt64Array), + (pa.uint32(), pa.UInt32Array), + (pa.uint16(), pa.UInt16Array) + ] + for typ, klass in cases: + casted = arr.cast(typ) + assert casted.type == typ + assert isinstance(casted, klass) + + # test chunked array casting + arr = pa.chunked_array([arr, arr]) + casted = arr.cast(pa.int16()) + assert casted.type == pa.int16() + assert isinstance(casted, pa.ChunkedArray) + + +def test_casting_to_extension_type_raises(): + arr = pa.array([1, 2, 3, 4], pa.int64()) + with pytest.raises(pa.ArrowNotImplementedError): + arr.cast(IntegerType()) + + def example_batch(): ty = ParamExtType(3) storage = pa.array([b"foo", b"bar"], type=pa.binary(3))