From 6891ba887506e983ef01ab00654c8dafb762b39b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 16 Mar 2020 15:39:34 +0100 Subject: [PATCH 1/9] support casting extension types --- cpp/src/arrow/compute/kernels/cast.cc | 3 +++ python/pyarrow/tests/test_extension_type.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index c0b68539e80..1b94a35fb24 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -1278,6 +1278,9 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr out_ty if (in_type.id() == Type::NA) { kernel->reset(new FromNullCastKernel(std::move(out_type))); return Status::OK(); + } else if (in_type.id() == Type::EXTENSION) { + auto storage_type = dynamic_cast(in_type).storage_type(); + return GetCastFunction(*storage_type, out_type, options, kernel); } std::unique_ptr cast_kernel; diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 64d28e19b47..bb8f78a5578 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -24,6 +24,15 @@ import pytest +class IntegerType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, pa.int64()) + + def __reduce__(self): + return IntegerType, () + + class UuidType(pa.PyExtensionType): def __init__(self): @@ -168,6 +177,17 @@ def test_ext_array_pickling(): assert arr.storage.to_pylist() == [b"foo", b"bar"] +def test_cast_kernel_on_extension_arrays(): + # test array casting + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(IntegerType(), storage) + assert arr.cast(pa.int32()).type == pa.int32() + + # test chunked array casting + arr = pa.chunked_array([arr, arr]) + assert arr.cast(pa.int16()).type == pa.int16() + + def example_batch(): ty = ParamExtType(3) storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) From 8473b24f19f0af3782d3373bce3bb3a16c741861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 17 Mar 2020 12:03:55 +0100 Subject: [PATCH 2/9] add ExtensionCastKernel --- cpp/src/arrow/compute/kernels/cast.cc | 51 +++++++++++++++++---- python/pyarrow/tests/test_extension_type.py | 24 +++++++++- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 1b94a35fb24..7cdec94f7ed 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -1127,6 +1127,42 @@ class ZeroCopyCast : public CastKernelBase { } }; +class ExtensionCastKernel : public CastKernelBase { + public: + static Status Make(const DataType& in_type, std::shared_ptr out_type, + const CastOptions& options, + std::unique_ptr* kernel) { + if (in_type.id() != Type::EXTENSION) { + return Status::TypeError("Not an extension type"); + } + auto storage_type = checked_cast(in_type).storage_type(); + + std::unique_ptr storage_caster; + RETURN_NOT_OK(GetCastFunction(*storage_type, out_type, options, &storage_caster)); + kernel->reset( + new ExtensionCastKernel(std::move(storage_caster), std::move(out_type))); + + return Status::OK(); + } + + Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override { + DCHECK_EQ(input.kind(), Datum::ARRAY); + if (input.type()->id() != Type::EXTENSION) { + return Status::TypeError("Not an extension type"); + } + auto new_input = input.array()->Copy(); + new_input->type = checked_cast(*new_input->type).storage_type(); + return InvokeWithAllocation(ctx, storage_caster_.get(), new_input, out); + } + + protected: + ExtensionCastKernel(std::unique_ptr storage_caster, + std::shared_ptr out_type) + : CastKernelBase(std::move(out_type)), storage_caster_(std::move(storage_caster)) {} + + std::unique_ptr storage_caster_; +}; + class CastKernel : public CastKernelBase { public: CastKernel(const CastOptions& options, const CastFunction& func, @@ -1275,14 +1311,6 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr out_ty return Status::OK(); } - if (in_type.id() == Type::NA) { - kernel->reset(new FromNullCastKernel(std::move(out_type))); - return Status::OK(); - } else if (in_type.id() == Type::EXTENSION) { - auto storage_type = dynamic_cast(in_type).storage_type(); - return GetCastFunction(*storage_type, out_type, options, kernel); - } - std::unique_ptr cast_kernel; switch (in_type.id()) { CAST_FUNCTION_CASE(BooleanType); @@ -1307,6 +1335,9 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr out_ty CAST_FUNCTION_CASE(LargeBinaryType); CAST_FUNCTION_CASE(LargeStringType); CAST_FUNCTION_CASE(DictionaryType); + case Type::NA: + cast_kernel.reset(new FromNullCastKernel(std::move(out_type))); + break; case Type::LIST: RETURN_NOT_OK( GetListCastFunc(in_type, std::move(out_type), options, &cast_kernel)); @@ -1315,6 +1346,10 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr out_ty RETURN_NOT_OK(GetListCastFunc(in_type, std::move(out_type), options, &cast_kernel)); break; + case Type::EXTENSION: + RETURN_NOT_OK(ExtensionCastKernel::Make(std::move(in_type), std::move(out_type), + options, &cast_kernel)); + break; default: break; } diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index bb8f78a5578..05fabe4587b 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -181,11 +181,31 @@ def test_cast_kernel_on_extension_arrays(): # test array casting storage = pa.array([1, 2, 3, 4], pa.int64()) arr = pa.ExtensionArray.from_storage(IntegerType(), storage) - assert arr.cast(pa.int32()).type == pa.int32() + + cases = [ + (pa.int64(), pa.Int64Array), + (pa.int32(), pa.Int32Array), + (pa.int16(), pa.Int16Array), + (pa.uint64(), pa.UInt64Array), + (pa.uint32(), pa.UInt32Array), + (pa.uint16(), pa.UInt16Array) + ] + for typ, klass in cases: + casted = arr.cast(typ) + assert casted.type == typ + assert isinstance(casted, klass) # test chunked array casting arr = pa.chunked_array([arr, arr]) - assert arr.cast(pa.int16()).type == pa.int16() + casted = arr.cast(pa.int16()) + assert casted.type == pa.int16() + assert isinstance(casted, pa.ChunkedArray) + + +def test_casting_to_extension_type_raises(): + arr = pa.array([1, 2, 3, 4], pa.int64()) + with pytest.raises(pa.ArrowNotImplementedError): + arr.cast(IntegerType()) def example_batch(): From 600d7b0a1f56f0f87838c7287ad6e5c2023269e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 17 Mar 2020 20:16:25 +0100 Subject: [PATCH 3/9] C++ test --- cpp/src/arrow/compute/kernels/cast_test.cc | 47 ++++++++++++++++++++++ cpp/src/arrow/testing/extension_type.h | 28 +++++++++++++ cpp/src/arrow/testing/gtest_util.cc | 40 +++++++++++++++--- 3 files changed, 110 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/cast_test.cc b/cpp/src/arrow/compute/kernels/cast_test.cc index 7198a10c4d1..bc01d1997c8 100644 --- a/cpp/src/arrow/compute/kernels/cast_test.cc +++ b/cpp/src/arrow/compute/kernels/cast_test.cc @@ -26,9 +26,11 @@ #include "arrow/array.h" #include "arrow/buffer.h" +#include "arrow/extension_type.h" #include "arrow/memory_pool.h" #include "arrow/status.h" #include "arrow/table.h" +#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" @@ -1480,5 +1482,50 @@ TYPED_TEST(TestDictionaryCast, OutTypeError) { this->CheckPass(*plain_array, *dict_array, dict_array->type(), options); }*/ +std::shared_ptr SmallintArrayFromJSON(const std::string& json_data) { + auto arr = ArrayFromJSON(int16(), json_data); + auto ext_data = arr->data()->Copy(); + ext_data->type = smallint(); + return MakeArray(ext_data); +} + +TEST_F(TestCast, ExtensionTypeToIntDowncast) { + auto smallint = std::make_shared(); + ASSERT_OK(RegisterExtensionType(smallint)); + + CastOptions options; + options.allow_int_overflow = false; + + std::shared_ptr result; + std::vector is_valid = {true, false, true, true, true}; + + // Smallint(int16) to uint8, no overflow/underrun + auto v1 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]"); + auto e1 = ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]"); + CheckPass(*v1, *e1, uint8(), options); + + // Smallint(int16) to uint8, with overflow + auto v2 = SmallintArrayFromJSON("[0, null, 256, 1, 3]"); + auto e2 = ArrayFromJSON(uint8(), "[0, null, 0, 1, 3]"); + // allow overflow + options.allow_int_overflow = true; + CheckPass(*v2, *e2, uint8(), options); + // disallow overflow + options.allow_int_overflow = false; + ASSERT_RAISES(Invalid, Cast(&ctx_, *v2, uint8(), options, &result)); + + // Smallint(int16) to uint8, with underflow + auto v3 = SmallintArrayFromJSON("[0, null, -1, 1, 0]"); + auto e3 = ArrayFromJSON(uint8(), "[0, null, 255, 1, 0]"); + // allow overflow + options.allow_int_overflow = true; + CheckPass(*v3, *e3, uint8(), options); + // disallow overflow + options.allow_int_overflow = false; + ASSERT_RAISES(Invalid, Cast(&ctx_, *v3, uint8(), options, &result)); + + ASSERT_OK(UnregisterExtensionType("smallint")); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/testing/extension_type.h b/cpp/src/arrow/testing/extension_type.h index d29ddccdc5d..0384a28f343 100644 --- a/cpp/src/arrow/testing/extension_type.h +++ b/cpp/src/arrow/testing/extension_type.h @@ -48,10 +48,38 @@ class ARROW_EXPORT UUIDType : public ExtensionType { std::string Serialize() const override { return "uuid-type-unique-code"; } }; +class ARROW_EXPORT SmallintArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +class ARROW_EXPORT SmallintType : public ExtensionType { + public: + SmallintType() : ExtensionType(int16()) {} + + std::string extension_name() const override { return "smallint"; } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + Status Deserialize(std::shared_ptr storage_type, + const std::string& serialized, + std::shared_ptr* out) const override; + + std::string Serialize() const override { return "smallint"; } +}; + ARROW_EXPORT std::shared_ptr uuid(); +ARROW_EXPORT +std::shared_ptr smallint(); + ARROW_EXPORT std::shared_ptr ExampleUUID(); +ARROW_EXPORT +std::shared_ptr ExampleSmallint(); + } // namespace arrow diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 8caf3f1cec9..009ee80b8a7 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -382,11 +382,7 @@ void SleepFor(double seconds) { // Extension types bool UUIDType::ExtensionEquals(const ExtensionType& other) const { - const auto& other_ext = static_cast(other); - if (other_ext.extension_name() != this->extension_name()) { - return false; - } - return true; + return (other.extension_name() == this->extension_name()); } std::shared_ptr UUIDType::MakeArray(std::shared_ptr data) const { @@ -423,4 +419,38 @@ std::shared_ptr ExampleUUID() { return MakeArray(ext_data); } +bool SmallintType::ExtensionEquals(const ExtensionType& other) const { + return (other.extension_name() == this->extension_name()); +} + +std::shared_ptr SmallintType::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("smallint", static_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Status SmallintType::Deserialize(std::shared_ptr storage_type, + const std::string& serialized, + std::shared_ptr* out) const { + if (serialized != "smallint") { + return Status::Invalid("Type identifier did not match"); + } + if (!storage_type->Equals(*int16())) { + return Status::Invalid("Invalid storage type for SmallintType"); + } + *out = std::make_shared(); + return Status::OK(); +} + +std::shared_ptr smallint() { return std::make_shared(); } + +std::shared_ptr ExampleSmallint() { + auto storage_type = int16(); + auto ext_type = smallint(); + auto arr = ArrayFromJSON(storage_type, "[-32768, null, 1, 2, 3, 4, 32767]"); + auto ext_data = arr->data()->Copy(); + ext_data->type = ext_type; + return MakeArray(ext_data); +} + } // namespace arrow From ec0c425f6337d8b23608ec56de60f92bb7626a9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 17 Mar 2020 21:01:50 +0100 Subject: [PATCH 4/9] Validate extension type's name and storage type --- cpp/src/arrow/compute/kernels/cast.cc | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 7cdec94f7ed..9a1759e5945 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -1135,7 +1135,7 @@ class ExtensionCastKernel : public CastKernelBase { if (in_type.id() != Type::EXTENSION) { return Status::TypeError("Not an extension type"); } - auto storage_type = checked_cast(in_type).storage_type(); + const auto storage_type = checked_cast(in_type).storage_type(); std::unique_ptr storage_caster; RETURN_NOT_OK(GetCastFunction(*storage_type, out_type, options, &storage_caster)); @@ -1145,13 +1145,31 @@ class ExtensionCastKernel : public CastKernelBase { return Status::OK(); } + Status Init(const DataType& in_type) override { + auto& type = checked_cast(in_type); + storage_type_ = type.storage_type(); + extension_name_ = type.extension_name(); + return Status::OK(); + } + Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override { DCHECK_EQ(input.kind(), Datum::ARRAY); if (input.type()->id() != Type::EXTENSION) { return Status::TypeError("Not an extension type"); } + + // validate: type is the same as the type the kernel was constructed with + auto& input_type = checked_cast(*input.type()); + if (input_type.extension_name() != extension_name_) { + return Status::TypeError("EEEE"); + } + if (!input_type.storage_type()->Equals(storage_type_)) { + return Status::TypeError("FFF"); + } + + // construct an ArrayData object with the underlying storage type auto new_input = input.array()->Copy(); - new_input->type = checked_cast(*new_input->type).storage_type(); + new_input->type = storage_type_; return InvokeWithAllocation(ctx, storage_caster_.get(), new_input, out); } @@ -1160,6 +1178,8 @@ class ExtensionCastKernel : public CastKernelBase { std::shared_ptr out_type) : CastKernelBase(std::move(out_type)), storage_caster_(std::move(storage_caster)) {} + std::string extension_name_; + std::shared_ptr storage_type_; std::unique_ptr storage_caster_; }; From 26750bec8412a689c53be6c12cc01a56336c2a4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 17 Mar 2020 21:05:40 +0100 Subject: [PATCH 5/9] Better validation messages --- cpp/src/arrow/compute/kernels/cast.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 9a1759e5945..72b39c86f80 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -1161,10 +1161,12 @@ class ExtensionCastKernel : public CastKernelBase { // validate: type is the same as the type the kernel was constructed with auto& input_type = checked_cast(*input.type()); if (input_type.extension_name() != extension_name_) { - return Status::TypeError("EEEE"); + return Status::TypeError( + "The cast kernel was constructed with a differently named extension type"); } if (!input_type.storage_type()->Equals(storage_type_)) { - return Status::TypeError("FFF"); + return Status::TypeError( + "The cast kernel was constructed with a different extension type"); } // construct an ArrayData object with the underlying storage type From bfdde93a3ca21328ada03c0dae50aad73c2b72e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 17 Mar 2020 21:48:52 +0100 Subject: [PATCH 6/9] Apply suggestions from code review Co-Authored-By: Benjamin Kietzman --- cpp/src/arrow/compute/kernels/cast.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 72b39c86f80..641d36ef9d6 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -1159,10 +1159,11 @@ class ExtensionCastKernel : public CastKernelBase { } // validate: type is the same as the type the kernel was constructed with - auto& input_type = checked_cast(*input.type()); + const auto& input_type = checked_cast(*input.type()); if (input_type.extension_name() != extension_name_) { return Status::TypeError( - "The cast kernel was constructed with a differently named extension type"); + "The cast kernel was constructed to cast from the extension type named '", extension_name_, + "' but input has named extension type name '", input_type.extension_name(), "'"); } if (!input_type.storage_type()->Equals(storage_type_)) { return Status::TypeError( From 957f191d71c9ee8e6f29ec66d153823e3b0eb2f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 17 Mar 2020 22:10:46 +0100 Subject: [PATCH 7/9] lint --- cpp/src/arrow/compute/kernels/cast.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 641d36ef9d6..e35741e80f2 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -1162,8 +1162,9 @@ class ExtensionCastKernel : public CastKernelBase { const auto& input_type = checked_cast(*input.type()); if (input_type.extension_name() != extension_name_) { return Status::TypeError( - "The cast kernel was constructed to cast from the extension type named '", extension_name_, - "' but input has named extension type name '", input_type.extension_name(), "'"); + "The cast kernel was constructed to cast from the extension type named '", + extension_name_, "' but input has extension type named '", + input_type.extension_name(), "'"); } if (!input_type.storage_type()->Equals(storage_type_)) { return Status::TypeError( From fcde06d5f98090d678fddd99668097da12767179 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 18 Mar 2020 11:11:08 +0100 Subject: [PATCH 8/9] zero copy tests and better error messages --- cpp/src/arrow/compute/kernels/cast.cc | 12 ++++-------- cpp/src/arrow/compute/kernels/cast_test.cc | 4 ++++ python/pyarrow/tests/test_extension_type.py | 5 +++++ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index e35741e80f2..bdac0225502 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -1132,9 +1132,6 @@ class ExtensionCastKernel : public CastKernelBase { static Status Make(const DataType& in_type, std::shared_ptr out_type, const CastOptions& options, std::unique_ptr* kernel) { - if (in_type.id() != Type::EXTENSION) { - return Status::TypeError("Not an extension type"); - } const auto storage_type = checked_cast(in_type).storage_type(); std::unique_ptr storage_caster; @@ -1154,9 +1151,6 @@ class ExtensionCastKernel : public CastKernelBase { Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override { DCHECK_EQ(input.kind(), Datum::ARRAY); - if (input.type()->id() != Type::EXTENSION) { - return Status::TypeError("Not an extension type"); - } // validate: type is the same as the type the kernel was constructed with const auto& input_type = checked_cast(*input.type()); @@ -1167,8 +1161,10 @@ class ExtensionCastKernel : public CastKernelBase { input_type.extension_name(), "'"); } if (!input_type.storage_type()->Equals(storage_type_)) { - return Status::TypeError( - "The cast kernel was constructed with a different extension type"); + return Status::TypeError("The cast kernel was constructed with a storage type: ", + storage_type_->ToString(), + ", but it is called with a different storage type:", + input_type.storage_type()->ToString()); } // construct an ArrayData object with the underlying storage type diff --git a/cpp/src/arrow/compute/kernels/cast_test.cc b/cpp/src/arrow/compute/kernels/cast_test.cc index bc01d1997c8..65ae570f50c 100644 --- a/cpp/src/arrow/compute/kernels/cast_test.cc +++ b/cpp/src/arrow/compute/kernels/cast_test.cc @@ -1499,6 +1499,10 @@ TEST_F(TestCast, ExtensionTypeToIntDowncast) { std::shared_ptr result; std::vector is_valid = {true, false, true, true, true}; + // Smallint(int16) to int16 + auto v0 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]"); + CheckZeroCopy(*v0, int16()); + // Smallint(int16) to uint8, no overflow/underrun auto v1 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]"); auto e1 = ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]"); diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 05fabe4587b..0d07ac143ea 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -182,6 +182,11 @@ def test_cast_kernel_on_extension_arrays(): storage = pa.array([1, 2, 3, 4], pa.int64()) arr = pa.ExtensionArray.from_storage(IntegerType(), storage) + # test that no allocation happens during identity cast + allocated_before_cast = pa.total_allocated_bytes() + casted = arr.cast(pa.int64()) + assert pa.total_allocated_bytes() == allocated_before_cast + cases = [ (pa.int64(), pa.Int64Array), (pa.int32(), pa.Int32Array), From e03c83fe769520b1eba7c079e84fcab88d5b715d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 18 Mar 2020 11:22:23 +0100 Subject: [PATCH 9/9] iwyu --- cpp/src/arrow/compute/kernels/cast.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index bdac0225502..48670bf0452 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include