diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index b984bc10425..da79ab92b5a 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -155,6 +155,7 @@ set(ARROW_SRCS array/diff.cc array/util.cc array/validate.cc + extensions/complex_type.cc builder.cc buffer.cc chunked_array.cc diff --git a/cpp/src/arrow/extension_type_test.cc b/cpp/src/arrow/extension_type_test.cc index 31222d74806..17a6c139ad4 100644 --- a/cpp/src/arrow/extension_type_test.cc +++ b/cpp/src/arrow/extension_type_test.cc @@ -27,6 +27,7 @@ #include "arrow/array/array_nested.h" #include "arrow/array/util.h" #include "arrow/extension_type.h" +#include "arrow/extensions/complex_type.h" #include "arrow/io/memory.h" #include "arrow/ipc/options.h" #include "arrow/ipc/reader.h" @@ -178,7 +179,9 @@ class ExtStructType : public ExtensionType { class TestExtensionType : public ::testing::Test { public: - void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared())); } + void SetUp() { + ASSERT_OK(RegisterExtensionType(std::make_shared())); + } void TearDown() { if (GetExtensionType("uuid")) { @@ -187,6 +190,23 @@ class TestExtensionType : public ::testing::Test { } }; +TEST_F(TestExtensionType, ComplexTypeTest) { + auto registered_type = GetExtensionType("arrow.complex64"); + ASSERT_NE(registered_type, nullptr); + + auto type = complex64(); + ASSERT_EQ(type->id(), Type::EXTENSION); + + const auto& ext_type = static_cast(*type); + std::string serialized = ext_type.Serialize(); + + ASSERT_OK_AND_ASSIGN(auto deserialized, + ext_type.Deserialize(fixed_size_list(float32(), 2), serialized)); + + ASSERT_TRUE(deserialized->Equals(*type)); + ASSERT_FALSE(deserialized->Equals(*fixed_size_list(float32(), 2))); +} + TEST_F(TestExtensionType, ExtensionTypeTest) { auto type_not_exist = GetExtensionType("uuid-unknown"); ASSERT_EQ(type_not_exist, nullptr); diff --git a/cpp/src/arrow/extensions/complex_type.cc b/cpp/src/arrow/extensions/complex_type.cc new file mode 100644 index 00000000000..8cf89edcd7e --- /dev/null +++ b/cpp/src/arrow/extensions/complex_type.cc @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Complex Number Extension Type + +#include +#include +#include + +#include "arrow/extensions/complex_type.h" + +namespace arrow { + +bool ComplexFloatType::ExtensionEquals(const ExtensionType& other) const { + const auto& other_ext = static_cast(other); + return other_ext.extension_name() == this->extension_name(); +} + +bool ComplexDoubleType::ExtensionEquals(const ExtensionType& other) const { + const auto& other_ext = static_cast(other); + return other_ext.extension_name() == this->extension_name(); +} + + +std::shared_ptr complex64() { + return std::make_shared(); +} + +std::shared_ptr complex128() { + return std::make_shared(); +} + +/// NOTE(sjperkins) +// Suggestions on how to improve this welcome! +std::once_flag complex_float_registered; +std::once_flag complex_double_registered; + +Status register_complex_types() +{ + std::call_once(complex_float_registered, + RegisterExtensionType, + std::make_shared()); + + std::call_once(complex_double_registered, + RegisterExtensionType, + std::make_shared()); + + return Status::OK(); +} + +static Status complex_types_registered = register_complex_types(); + +}; // namespace arrow diff --git a/cpp/src/arrow/extensions/complex_type.h b/cpp/src/arrow/extensions/complex_type.h new file mode 100644 index 00000000000..9a425a210ab --- /dev/null +++ b/cpp/src/arrow/extensions/complex_type.h @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Complex Number Extension Type +#pragma once + +#include + +#include "arrow/extension_type.h" + +namespace arrow { + + +std::shared_ptr complex64(); +std::shared_ptr complex128(); + + +class ComplexFloatArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +class ComplexFloatType : public ExtensionType { + public: + using c_type = std::complex; + + explicit ComplexFloatType() + : ExtensionType(fixed_size_list(float32(), 2)) {} + + std::string name() const override { + return "complex64"; + } + + std::string extension_name() const override { + return "arrow.complex64"; + } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::shared_ptr MakeArray(std::shared_ptr data) const override { + return std::make_shared(data); + } + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized) const override { + return complex64(); + }; + + std::string Serialize() const override { + return ""; + } +}; + + +class ComplexDoubleArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +class ComplexDoubleType : public ExtensionType { + public: + using c_type = std::complex; + + explicit ComplexDoubleType() + : ExtensionType(fixed_size_list(float64(), 2)) {} + + std::string name() const override { + return "complex128"; + } + + std::string extension_name() const override { + return "arrow.complex128"; + } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::shared_ptr MakeArray(std::shared_ptr data) const override { + return std::make_shared(data); + } + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized) const override { + return complex128(); + }; + + std::string Serialize() const override { + return ""; + } +}; + +}; // namespace arrow diff --git a/cpp/src/arrow/python/CMakeLists.txt b/cpp/src/arrow/python/CMakeLists.txt index 835eacad1f6..078c22a2f83 100644 --- a/cpp/src/arrow/python/CMakeLists.txt +++ b/cpp/src/arrow/python/CMakeLists.txt @@ -44,7 +44,8 @@ set(ARROW_PYTHON_SRCS numpy_to_arrow.cc python_to_arrow.cc pyarrow.cc - serialize.cc) + serialize.cc + type_traits.cc) set_source_files_properties(init.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON SKIP_UNITY_BUILD_INCLUSION ON) diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index 41e537191a3..4cdcb025f3e 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -17,6 +17,7 @@ // Functions for pandas conversion via NumPy +#include "arrow/extensions/complex_type.h" #include "arrow/python/arrow_to_pandas.h" #include "arrow/python/numpy_interop.h" // IWYU pragma: expand @@ -40,6 +41,7 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/hashing.h" #include "arrow/util/int_util.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/parallel.h" @@ -331,6 +333,8 @@ class PandasWriter { HALF_FLOAT, FLOAT, DOUBLE, + COMPLEX_FLOAT, + COMPLEX_DOUBLE, BOOL, DATETIME_DAY, DATETIME_SECOND, @@ -915,6 +919,15 @@ inline void ConvertNumericNullableCast(const ChunkedArray& data, InType na_value } } +template +inline void ConvertNumericNullableComplex(const ChunkedArray& data, + OutType* out_values) { + for (int c = 0; c < data.num_chunks(); c++) { + const auto& arr = *data.chunk(c); + arr.num_fields(); + } +} + template class TypedPandasWriter : public PandasWriter { public: @@ -1234,6 +1247,7 @@ class IntWriter : public TypedPandasWriter { } }; + template class FloatWriter : public TypedPandasWriter { public: @@ -1242,7 +1256,7 @@ class FloatWriter : public TypedPandasWriter { using T = typename ArrowType::c_type; bool CanZeroCopy(const ChunkedArray& data) const override { - return IsNonNullContiguous(data) && data.type()->id() == ArrowType::type_id; + return data.type()->id() == ArrowType::type_id && IsNonNullContiguous(data); } Status CopyInto(std::shared_ptr data, int64_t rel_placement) override { @@ -1278,6 +1292,26 @@ class FloatWriter : public TypedPandasWriter { case Type::DOUBLE: ConvertNumericNullableCast(*data, npy_traits::na_sentinel, out_values); break; + case Type::EXTENSION: + { + auto ext_type = std::static_pointer_cast(data->type()); + + if(ext_type == nullptr) { + return Status::TypeError( + "Unable to cast ", data->type()->ToString(), "to ExtensionType"); + } + + if(ext_type->extension_name() == "arrow.complex64") { + ConvertNumericNullableComplex(*data, out_values); + } else if (ext_type->extension_name() == "arrow.complex128") { + ConvertNumericNullableComplex(*data, out_values); + } else { + return Status::NotImplemented("Cannot write Arrow data of type ", + data->type()->ToString(), + " to a Pandas floating point block"); + } + } + break; default: return Status::NotImplemented("Cannot write Arrow data of type ", data->type()->ToString(), @@ -1290,6 +1324,82 @@ class FloatWriter : public TypedPandasWriter { } }; + +template +class ComplexWriter : public TypedPandasWriter { + public: + using ArrowType = typename npy_traits::TypeClass; + using TypedPandasWriter::TypedPandasWriter; + using T = typename ArrowType::c_type; + + bool CanZeroCopy(const ChunkedArray& data) const override { + return data.type()->id() == ArrowType::type_id && IsNonNullContiguous(data); + } + + Status CopyInto(std::shared_ptr data, int64_t rel_placement) override { + Type::type in_type = data->type()->id(); + auto out_values = this->GetBlockColumnStart(rel_placement); + +#define INTEGER_CASE(IN_TYPE) \ + ConvertIntegerWithNulls(this->options_, *data, out_values); \ + break; + + switch (in_type) { + case Type::UINT8: + INTEGER_CASE(uint8_t); + case Type::INT8: + INTEGER_CASE(int8_t); + case Type::UINT16: + INTEGER_CASE(uint16_t); + case Type::INT16: + INTEGER_CASE(int16_t); + case Type::UINT32: + INTEGER_CASE(uint32_t); + case Type::INT32: + INTEGER_CASE(int32_t); + case Type::UINT64: + INTEGER_CASE(uint64_t); + case Type::INT64: + INTEGER_CASE(int64_t); + case Type::HALF_FLOAT: + ConvertNumericNullableCast(*data, npy_traits::na_sentinel, out_values); + case Type::FLOAT: + ConvertNumericNullableCast(*data, npy_traits::na_sentinel, out_values); + break; + case Type::DOUBLE: + ConvertNumericNullableCast(*data, npy_traits::na_sentinel, out_values); + break; + case Type::EXTENSION: + { + auto ext_type = std::static_pointer_cast(data->type()); + + if(ext_type == nullptr) { + return Status::TypeError( + "Unable to cast ", data->type()->ToString(), "to ExtensionType"); + } + + if(ext_type->extension_name() == "arrow.complex64") { + ConvertNumericNullableComplex(*data, out_values); + } else if (ext_type->extension_name() == "arrow.complex128") { + ConvertNumericNullableComplex(*data, out_values); + } else { + return Status::NotImplemented("Cannot write Arrow data of type ", + data->type()->ToString(), + " to a Pandas complex number block"); + } + } + break; + default: + return Status::NotImplemented("Cannot write Arrow data of type ", + data->type()->ToString(), + " to a Pandas complex number block"); + } +#undef INTEGER_CASE + + return Status::OK(); + } +}; + using UInt8Writer = IntWriter; using Int8Writer = IntWriter; using UInt16Writer = IntWriter; @@ -1301,6 +1411,8 @@ using Int64Writer = IntWriter; using Float16Writer = FloatWriter; using Float32Writer = FloatWriter; using Float64Writer = FloatWriter; +using Complex64Writer = ComplexWriter; +using Complex128Writer = ComplexWriter; class BoolWriter : public TypedPandasWriter { public: @@ -1501,7 +1613,7 @@ class TimedeltaWriter : public TypedPandasWriter { bool CanZeroCopy(const ChunkedArray& data) const override { const auto& type = checked_cast(*data.type()); - return IsNonNullContiguous(data) && type.unit() == UNIT; + return type.unit() == UNIT && IsNonNullContiguous(data); } Status CopyInto(std::shared_ptr data, int64_t rel_placement) override { @@ -1824,6 +1936,8 @@ Status MakeWriter(const PandasOptions& options, PandasWriter::type writer_type, BLOCK_CASE(HALF_FLOAT, Float16Writer); BLOCK_CASE(FLOAT, Float32Writer); BLOCK_CASE(DOUBLE, Float64Writer); + BLOCK_CASE(COMPLEX_FLOAT, Complex64Writer); + BLOCK_CASE(COMPLEX_DOUBLE, Complex128Writer); BLOCK_CASE(BOOL, BoolWriter); BLOCK_CASE(DATETIME_DAY, DatetimeDayWriter); BLOCK_CASE(DATETIME_SECOND, DatetimeSecondWriter); @@ -1848,7 +1962,8 @@ Status MakeWriter(const PandasOptions& options, PandasWriter::type writer_type, return Status::OK(); } -static Status GetPandasWriterType(const ChunkedArray& data, const PandasOptions& options, +static Status GetPandasWriterType(const ChunkedArray& data, + const PandasOptions& options, PandasWriter::type* output_type) { #define INTEGER_CASE(NAME) \ *output_type = \ @@ -1972,6 +2087,16 @@ static Status GetPandasWriterType(const ChunkedArray& data, const PandasOptions& *output_type = PandasWriter::CATEGORICAL; break; case Type::EXTENSION: + { + auto ext_type = std::static_pointer_cast(data.type()); + + if(ext_type->extension_name() == "arrow.complex64") { + *output_type = PandasWriter::COMPLEX_FLOAT; + } else if (ext_type->extension_name() == "arrow.complex128") { + *output_type = PandasWriter::COMPLEX_DOUBLE; + } + } + *output_type = PandasWriter::EXTENSION; break; default: @@ -2057,7 +2182,8 @@ class ConsolidatedBlockCreator : public PandasBlockCreator { *out = PandasWriter::EXTENSION; return Status::OK(); } else { - return GetPandasWriterType(*arrays_[column_index], options_, out); + return GetPandasWriterType(*arrays_[column_index], + options_, out); } } @@ -2163,7 +2289,8 @@ class SplitBlockCreator : public PandasBlockCreator { output_type = PandasWriter::EXTENSION; } else { // Null count needed to determine output type - RETURN_NOT_OK(GetPandasWriterType(*arrays_[i], options_, &output_type)); + RETURN_NOT_OK(GetPandasWriterType(*arrays_[i], + options_, &output_type)); } return MakeWriter(this->options_, output_type, type, num_rows_, 1, writer); } @@ -2197,6 +2324,44 @@ class SplitBlockCreator : public PandasBlockCreator { std::vector> writers_; }; + +Status ConvertComplexArrays(const PandasOptions& options, + ChunkedArrayVector* arrays, + FieldVector* fields, + PandasOptions* modified_options) { + + for (int i = 0; i < static_cast(arrays->size()); i++) { + auto array = (*arrays)[i]; + auto field = (*fields)[i]; + + if (array->type()->id() == Type::EXTENSION) { + auto ext = std::static_pointer_cast(array->type()); + bool is_f32 = ext->extension_name() == "arrow.complex64"; + bool is_f64 = !is_f32 && ext->extension_name() == "arrow.complex128"; + + if(is_f32 || is_f64) { + ArrayVector chunks; + + for(int c=0; c < array->num_chunks(); ++c) { + auto ext = std::static_pointer_cast(array->chunk(c)); + auto storage = std::static_pointer_cast(ext->storage()); + chunks.push_back(storage->Flatten().ValueOrDie()); + } + + auto dtype = is_f32 ? float32() : float64(); + auto meta = key_value_metadata({"__complex_field_marker__"}, {"true"}); + + (*arrays)[i] = std::make_shared(chunks, dtype); + (*fields)[i] = field->WithType(dtype)->WithMergedMetadata(meta); + modified_options->extension_columns.erase(field->name()); + } + } + } + + return Status::OK(); +} + + Status ConvertCategoricals(const PandasOptions& options, ChunkedArrayVector* arrays, FieldVector* fields) { std::vector columns_to_encode; @@ -2300,10 +2465,11 @@ Status ConvertTableToPandas(const PandasOptions& options, std::shared_ptr // ARROW-3789: allow "self-destructing" by releasing references to columns as // we convert them to pandas table = nullptr; + PandasOptions modified_options = options; RETURN_NOT_OK(ConvertCategoricals(options, &arrays, &fields)); + RETURN_NOT_OK(ConvertComplexArrays(options, &arrays, &fields, &modified_options)); - PandasOptions modified_options = options; modified_options.strings_to_categorical = false; modified_options.categorical_columns.clear(); diff --git a/cpp/src/arrow/python/numpy_convert.cc b/cpp/src/arrow/python/numpy_convert.cc index 49706807644..20427a7ddb5 100644 --- a/cpp/src/arrow/python/numpy_convert.cc +++ b/cpp/src/arrow/python/numpy_convert.cc @@ -30,6 +30,8 @@ #include "arrow/type.h" #include "arrow/util/logging.h" +#include "arrow/extensions/complex_type.h" + #include "arrow/python/common.h" #include "arrow/python/pyarrow.h" #include "arrow/python/type_traits.h" @@ -84,6 +86,12 @@ Status GetTensorType(PyObject* dtype, std::shared_ptr* out) { TO_ARROW_TYPE_CASE(FLOAT16, float16); TO_ARROW_TYPE_CASE(FLOAT32, float32); TO_ARROW_TYPE_CASE(FLOAT64, float64); + case NPY_COMPLEX64: + *out = complex64(); + break; + case NPY_COMPLEX128: + *out = complex128(); + break; default: { return Status::NotImplemented("Unsupported numpy type ", descr->type_num); } @@ -109,6 +117,21 @@ Status GetNumPyType(const DataType& type, int* type_num) { NUMPY_TYPE_CASE(HALF_FLOAT, FLOAT16); NUMPY_TYPE_CASE(FLOAT, FLOAT32); NUMPY_TYPE_CASE(DOUBLE, FLOAT64); + case Type::EXTENSION: { + auto ext = static_cast(&type); + + if (ext->extension_name() == "arrow.complex64") { + *type_num = NPY_COMPLEX64; + break; + } else if (ext->extension_name() == "arrow.complex128") { + *type_num = NPY_COMPLEX128; + break; + } else { + return Status::NotImplemented("Unsupported ExtensionType: ", + ext->extension_name()); + } + } + default: { return Status::NotImplemented("Unsupported tensor type: ", type.ToString()); } @@ -144,6 +167,8 @@ Status NumPyDtypeToArrow(PyArray_Descr* descr, std::shared_ptr* out) { TO_ARROW_TYPE_CASE(FLOAT16, float16); TO_ARROW_TYPE_CASE(FLOAT32, float32); TO_ARROW_TYPE_CASE(FLOAT64, float64); + TO_ARROW_TYPE_CASE(COMPLEX64, complex64); + TO_ARROW_TYPE_CASE(COMPLEX128, complex128); TO_ARROW_TYPE_CASE(STRING, binary); TO_ARROW_TYPE_CASE(UNICODE, utf8); case NPY_DATETIME: { diff --git a/cpp/src/arrow/python/numpy_internal.h b/cpp/src/arrow/python/numpy_internal.h index 973f577cb13..cb4eacb0297 100644 --- a/cpp/src/arrow/python/numpy_internal.h +++ b/cpp/src/arrow/python/numpy_internal.h @@ -102,6 +102,8 @@ static inline std::string GetNumPyTypeName(int npy_type) { TYPE_CASE(FLOAT16, "float16") TYPE_CASE(FLOAT32, "float32") TYPE_CASE(FLOAT64, "float64") + TYPE_CASE(COMPLEX64, "complex64") + TYPE_CASE(COMPLEX128, "complex128") TYPE_CASE(DATETIME, "datetime64") TYPE_CASE(TIMEDELTA, "timedelta64") TYPE_CASE(OBJECT, "object") @@ -143,6 +145,8 @@ inline Status VisitNumpyArrayInline(PyArrayObject* arr, VISITOR* visitor) { TYPE_VISIT_INLINE(FLOAT16); TYPE_VISIT_INLINE(FLOAT32); TYPE_VISIT_INLINE(FLOAT64); + TYPE_VISIT_INLINE(COMPLEX64); + TYPE_VISIT_INLINE(COMPLEX128); TYPE_VISIT_INLINE(DATETIME); TYPE_VISIT_INLINE(TIMEDELTA); TYPE_VISIT_INLINE(OBJECT); diff --git a/cpp/src/arrow/python/numpy_to_arrow.cc b/cpp/src/arrow/python/numpy_to_arrow.cc index b6121846206..fdba1e2ff64 100644 --- a/cpp/src/arrow/python/numpy_to_arrow.cc +++ b/cpp/src/arrow/python/numpy_to_arrow.cc @@ -232,6 +232,8 @@ class NumPyConverter { Status Visit(const FixedSizeBinaryType& type); + Status Visit(const ExtensionType& type); + // Default case Status Visit(const DataType& type) { return TypeNotImplemented(type.ToString()); } @@ -468,6 +470,7 @@ inline Status NumPyConverter::ConvertData(std::shared_ptr* data) { return Status::OK(); } + template <> inline Status NumPyConverter::ConvertData(std::shared_ptr* data) { std::shared_ptr input_type; @@ -585,6 +588,43 @@ Status NumPyConverter::Visit(const BinaryType& type) { return Status::OK(); } + +Status NumPyConverter::Visit(const ExtensionType& type) { + if(type.extension_name() == "arrow.complex64") { + if (mask_ != nullptr) { + RETURN_NOT_OK(InitNullBitmap()); + null_count_ = MaskToBitmap(mask_, length_, null_bitmap_data_); + } else { + RETURN_NOT_OK(NumPyNullsConverter::Convert(pool_, arr_, from_pandas_, &null_bitmap_, + &null_count_)); + } + + std::shared_ptr data; + RETURN_NOT_OK(ConvertData(&data)); + + auto float_arr_data = ArrayData::Make(float32(), length_*2, {nullptr, data}, 0, 0); + auto arr_data = ArrayData::Make(type_, length_, {null_bitmap_}, {float_arr_data}, null_count_, 0); + return PushArray(arr_data); + } else if(type.extension_name() == "arrow.complex128") { + if (mask_ != nullptr) { + RETURN_NOT_OK(InitNullBitmap()); + null_count_ = MaskToBitmap(mask_, length_, null_bitmap_data_); + } else { + RETURN_NOT_OK(NumPyNullsConverter::Convert(pool_, arr_, from_pandas_, &null_bitmap_, + &null_count_)); + } + + std::shared_ptr data; + RETURN_NOT_OK(ConvertData(&data)); + + auto float_arr_data = ArrayData::Make(float64(), length_*2, {nullptr, data}, 0, 0); + auto arr_data = ArrayData::Make(type_, length_, {null_bitmap_}, {float_arr_data}, null_count_, 0); + return PushArray(arr_data); + } else { + return TypeNotImplemented(type.ToString()); + } +} + Status NumPyConverter::Visit(const FixedSizeBinaryType& type) { auto byte_width = type.byte_width(); diff --git a/cpp/src/arrow/python/python_to_arrow.cc b/cpp/src/arrow/python/python_to_arrow.cc index 21ced0898ef..ec8c8e7f815 100644 --- a/cpp/src/arrow/python/python_to_arrow.cc +++ b/cpp/src/arrow/python/python_to_arrow.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -35,6 +36,7 @@ #include "arrow/array/builder_primitive.h" #include "arrow/array/builder_time.h" #include "arrow/chunked_array.h" +#include "arrow/extensions/complex_type.h" #include "arrow/status.h" #include "arrow/type.h" #include "arrow/type_traits.h" @@ -256,6 +258,34 @@ class PyValue { return value; } + static Result> Convert(const ComplexFloatType*, const O&, I obj) { + std::complex value; + + if (PyComplex_Check(obj)) { + value = + std::complex(PyComplex_RealAsDouble(obj), PyComplex_ImagAsDouble(obj)); + RETURN_IF_PYERROR(); + } else { + return internal::InvalidValue(obj, "tried to convert to std::complex"); + } + + return value; + } + +static Result> Convert(const ComplexDoubleType*, const O&, I obj) { + std::complex value; + + if (PyComplex_Check(obj)) { + value = + std::complex(PyComplex_RealAsDouble(obj), PyComplex_ImagAsDouble(obj)); + RETURN_IF_PYERROR(); + } else { + return internal::InvalidValue(obj, "tried to convert to std::complex"); + } + + return value; + } + static Result Convert(const Decimal128Type* type, const O&, I obj) { Decimal128 value; RETURN_NOT_OK(internal::DecimalFromPyObject(obj, *type, &value)); diff --git a/cpp/src/arrow/python/type_traits.cc b/cpp/src/arrow/python/type_traits.cc new file mode 100644 index 00000000000..8558af24ad9 --- /dev/null +++ b/cpp/src/arrow/python/type_traits.cc @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/python/type_traits.h" + +namespace arrow { +namespace py { +namespace internal { + +constexpr std::complex npy_traits::na_sentinel; +constexpr std::complex npy_traits::na_sentinel; + +} // namespace internal +} // namespace py +} // namespace arrow diff --git a/cpp/src/arrow/python/type_traits.h b/cpp/src/arrow/python/type_traits.h index a941577f765..da39ff0e0e5 100644 --- a/cpp/src/arrow/python/type_traits.h +++ b/cpp/src/arrow/python/type_traits.h @@ -22,6 +22,7 @@ #include "arrow/python/platform.h" #include +#include #include #include "arrow/python/numpy_interop.h" @@ -126,6 +127,48 @@ struct npy_traits { static inline bool isnull(double v) { return v != v; } }; +template +constexpr std::complex make_complex_nan() +{ +} + +template <> +struct npy_traits { + using TypeClass = ComplexFloatType; + // NOTE(sjperkins) + // This should technically be FixedSizeListScalar, but FixedSizeListScalar + // isn't correctly sized for memcpy's in numpy_to_arrow.cc and doesn't + // have a default constructor either + using value_type = std::complex; + + static constexpr std::complex na_sentinel = + std::complex( + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()); + + static constexpr bool supports_nulls = true; + static inline bool isnull(const std::complex & v) { return v != v; } +}; + +template <> +struct npy_traits { + using TypeClass = ComplexDoubleType; + // NOTE(sjperkins) + // This should technically be FixedSizeListScalar, but FixedSizeListScalar + // isn't correctly sized for memcpy's in numpy_to_arrow.cc and doesn't + // have a default constructor either + using value_type = std::complex; + + static constexpr std::complex na_sentinel = + std::complex( + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()); + + static constexpr bool supports_nulls = true; + static inline bool isnull(const std::complex & v) { return v != v; } +}; + + template <> struct npy_traits { typedef int64_t value_type; @@ -298,12 +341,12 @@ struct arrow_traits { static inline NPY_DATETIMEUNIT NumPyFrequency(TimeUnit::type unit) { switch (unit) { - case TimestampType::Unit::SECOND: + case TimeUnit::SECOND: return NPY_FR_s; - case TimestampType::Unit::MILLI: + case TimeUnit::MILLI: return NPY_FR_ms; break; - case TimestampType::Unit::MICRO: + case TimeUnit::MICRO: return NPY_FR_us; default: // NANO @@ -334,6 +377,10 @@ static inline int NumPyTypeSize(int npy_type) { return 4; case NPY_FLOAT64: return 8; + case NPY_COMPLEX64: + return 8; + case NPY_COMPLEX128: + return 16; case NPY_DATETIME: return 8; case NPY_OBJECT: diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 45afd7af2e6..cb793ea68ce 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -207,6 +207,12 @@ _NUMERIC_TYPE_DECL(HalfFloat) _NUMERIC_TYPE_DECL(Float) _NUMERIC_TYPE_DECL(Double) +class ComplexFloatType; +class ComplexFloatArray; + +class ComplexDoubleType; +class ComplexDoubleArray; + #undef _NUMERIC_TYPE_DECL enum class DateUnit : char { DAY = 0, MILLI = 1 }; @@ -441,6 +447,10 @@ std::shared_ptr ARROW_EXPORT float16(); std::shared_ptr ARROW_EXPORT float32(); /// \brief Return a DoubleType instance std::shared_ptr ARROW_EXPORT float64(); +/// \brief Return a ComplexFloatType instance +std::shared_ptr ARROW_EXPORT complex64(); +/// \brief Return a ComplexDoubleType instance +std::shared_ptr ARROW_EXPORT complex128(); /// \brief Return a StringType instance std::shared_ptr ARROW_EXPORT utf8(); /// \brief Return a LargeStringType instance diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py index 687981ca5db..ad32c957642 100644 --- a/python/pyarrow/pandas_compat.py +++ b/python/pyarrow/pandas_compat.py @@ -63,6 +63,8 @@ def get_logical_type_map(): pa.lib.Type_BINARY: 'bytes', pa.lib.Type_FIXED_SIZE_BINARY: 'bytes', pa.lib.Type_STRING: 'unicode', + 'arrow.complex64': 'complex64', + 'arrow.complex128': 'complex128', }) return _logical_type_map @@ -81,6 +83,12 @@ def get_logical_type(arrow_type): return 'datetimetz' if arrow_type.tz is not None else 'datetime' elif isinstance(arrow_type, pa.lib.Decimal128Type): return 'decimal' + elif isinstance(arrow_type, pa.lib.BaseExtensionType): + try: + return logical_type_map[arrow_type.extension_name] + except KeyError: + pass + return 'object' @@ -96,6 +104,8 @@ def get_logical_type(arrow_type): np.uint64: 'uint64', np.float32: 'float32', np.float64: 'float64', + np.complex64: 'complex64', + np.complex128: 'complex128', 'datetime64[D]': 'date', np.unicode_: 'string', np.bytes_: 'bytes', @@ -743,11 +753,15 @@ def _reconstruct_block(item, columns=None, extension_columns=None): assert len(placement) == 1 name = columns[placement[0]] pandas_dtype = extension_columns[name] - if not hasattr(pandas_dtype, '__from_arrow__'): + + if pandas_dtype in {np.complex64, np.complex128}: + block = _int.make_block(arr, placement=placement) + elif not hasattr(pandas_dtype, '__from_arrow__'): raise ValueError("This column does not support to be converted " "to a pandas ExtensionArray") - pd_ext_arr = pandas_dtype.__from_arrow__(arr) - block = _int.make_block(pd_ext_arr, placement=placement) + else: + pd_ext_arr = pandas_dtype.__from_arrow__(arr) + block = _int.make_block(pd_ext_arr, placement=placement) else: block = _int.make_block(block_arr, placement=placement) @@ -793,11 +807,11 @@ def table_to_blockmanager(options, table, categories=None, # Set of the string repr of all numpy dtypes that can be stored in a pandas -# dataframe (complex not included since not supported by Arrow) +# dataframe _pandas_supported_numpy_types = { str(np.dtype(typ)) for typ in (np.sctypes['int'] + np.sctypes['uint'] + np.sctypes['float'] + - ['object', 'bool']) + ['object', 'bool'] + ['complex64', 'complex128']) } diff --git a/python/pyarrow/tests/test_pandas.py b/python/pyarrow/tests/test_pandas.py index f08461881fe..81cd17c0a19 100644 --- a/python/pyarrow/tests/test_pandas.py +++ b/python/pyarrow/tests/test_pandas.py @@ -67,6 +67,8 @@ def _alltypes_example(size=100): 'int64': np.arange(size, dtype=np.int64), 'float32': np.arange(size, dtype=np.float32), 'float64': np.arange(size, dtype=np.float64), + 'complex64': np.arange(size, dtype=np.complex64), + 'complex128': np.arange(size, dtype=np.complex128), 'bool': np.random.randn(size) > 0, # TODO(wesm): Pandas only support ns resolution, Arrow supports s, ms, # us, ns diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 32d70887aab..401e477b68c 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -52,6 +52,11 @@ cdef dict _pandas_type_map = { _Type_DECIMAL128: np.object_, } +cdef dict _pandas_ext_type_map = { + b'arrow.complex64': np.complex64, + b'arrow.complex128': np.complex128, +} + cdef dict _pep3118_type_map = { _Type_INT8: b'b', _Type_INT16: b'h', @@ -118,7 +123,7 @@ def _is_primitive(Type type): # Workaround for Cython parsing bug # https://github.com/cython/cython/issues/2143 ctypedef CFixedWidthType* _CFixedWidthTypePtr - +ctypedef const CExtensionType* _CExtensionTypePtr cdef class DataType(_Weakrefable): """ @@ -222,10 +227,22 @@ cdef class DataType(_Weakrefable): Return the equivalent NumPy / Pandas dtype. """ cdef Type type_id = self.type.id() - if type_id in _pandas_type_map: + cdef const CExtensionType * ext_type + cdef bytes ext_name + + if type_id == _Type_EXTENSION: + ext_type = dynamic_cast[_CExtensionTypePtr](self.type) + + if ext_type: + ext_name = ext_type.extension_name() + + if ext_name in _pandas_ext_type_map: + return _pandas_ext_type_map[ext_name] + + elif type_id in _pandas_type_map: return _pandas_type_map[type_id] - else: - raise NotImplementedError(str(self)) + + raise NotImplementedError(str(self)) def _export_to_c(self, out_ptr): """ @@ -738,6 +755,18 @@ cdef class BaseExtensionType(DataType): """ return pyarrow_wrap_data_type(self.ext_type.storage_type()) + + def __arrow_ext_class__(self): + """Return an extension array class to be used for building or + deserializing arrays with this extension type. + + This method should return a subclass of the ExtensionArray class. By + default, if not specialized in the extension implementation, an + extension type array will be a built-in ExtensionArray instance. + """ + return ExtensionArray + + def wrap_array(self, storage): """ Wrap the given storage array as an extension array. @@ -856,16 +885,6 @@ cdef class ExtensionType(BaseExtensionType): """ return NotImplementedError - def __arrow_ext_class__(self): - """Return an extension array class to be used for building or - deserializing arrays with this extension type. - - This method should return a subclass of the ExtensionArray class. By - default, if not specialized in the extension implementation, an - extension type array will be a built-in ExtensionArray instance. - """ - return ExtensionArray - cdef class PyExtensionType(ExtensionType): """