-
Notifications
You must be signed in to change notification settings
Fork 4k
ARROW-3428: [Python] Fix from_pandas conversion from float to bool #2698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,6 +63,7 @@ namespace arrow { | |
|
|
||
| using internal::checked_cast; | ||
| using internal::CopyBitmap; | ||
| using internal::GenerateBitsUnrolled; | ||
|
|
||
| namespace py { | ||
|
|
||
|
|
@@ -246,6 +247,11 @@ class NumPyConverter { | |
| return Status::OK(); | ||
| } | ||
|
|
||
| // Called before ConvertData to ensure the NumPy input buffer is in the expected | ||
| // Arrow layout | ||
| template <typename ArrowType> | ||
| Status PrepareInputData(std::shared_ptr<Buffer>* data); | ||
|
|
||
| // ---------------------------------------------------------------------- | ||
| // Traditional visitor conversion for non-object arrays | ||
|
|
||
|
|
@@ -407,14 +413,32 @@ Status CopyStridedArray(PyArrayObject* arr, const int64_t length, MemoryPool* po | |
| } // namespace | ||
|
|
||
| template <typename ArrowType> | ||
| inline Status NumPyConverter::ConvertData(std::shared_ptr<Buffer>* data) { | ||
| inline Status NumPyConverter::PrepareInputData(std::shared_ptr<Buffer>* data) { | ||
| if (is_strided()) { | ||
|
||
| RETURN_NOT_OK(CopyStridedArray<ArrowType>(arr_, length_, pool_, data)); | ||
| } else if (dtype_->type_num == NPY_BOOL) { | ||
|
||
| int64_t nbytes = BitUtil::BytesForBits(length_); | ||
| std::shared_ptr<Buffer> buffer; | ||
| RETURN_NOT_OK(AllocateBuffer(pool_, nbytes, &buffer)); | ||
|
|
||
| Ndarray1DIndexer<uint8_t> values(arr_); | ||
| int64_t i = 0; | ||
| const auto generate = [&values, &i]() -> bool { return values[i++] > 0; }; | ||
| GenerateBitsUnrolled(buffer->mutable_data(), 0, length_, generate); | ||
|
||
|
|
||
| *data = buffer; | ||
| } else { | ||
| // Can zero-copy | ||
| *data = std::make_shared<NumPyBuffer>(reinterpret_cast<PyObject*>(arr_)); | ||
| } | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| template <typename ArrowType> | ||
| inline Status NumPyConverter::ConvertData(std::shared_ptr<Buffer>* data) { | ||
| RETURN_NOT_OK(PrepareInputData<ArrowType>(data)); | ||
|
|
||
| std::shared_ptr<DataType> input_type; | ||
| RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast<PyObject*>(dtype_), &input_type)); | ||
|
|
||
|
|
@@ -426,38 +450,12 @@ inline Status NumPyConverter::ConvertData(std::shared_ptr<Buffer>* data) { | |
| return Status::OK(); | ||
| } | ||
|
|
||
| template <> | ||
| inline Status NumPyConverter::ConvertData<BooleanType>(std::shared_ptr<Buffer>* data) { | ||
|
||
| int64_t nbytes = BitUtil::BytesForBits(length_); | ||
| std::shared_ptr<Buffer> buffer; | ||
| RETURN_NOT_OK(AllocateBuffer(pool_, nbytes, &buffer)); | ||
|
|
||
| Ndarray1DIndexer<uint8_t> values(arr_); | ||
|
||
|
|
||
| uint8_t* bitmap = buffer->mutable_data(); | ||
|
|
||
| memset(bitmap, 0, nbytes); | ||
| for (int i = 0; i < length_; ++i) { | ||
| if (values[i] > 0) { | ||
| BitUtil::SetBit(bitmap, i); | ||
| } | ||
| } | ||
|
|
||
| *data = buffer; | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| template <> | ||
| inline Status NumPyConverter::ConvertData<Date32Type>(std::shared_ptr<Buffer>* data) { | ||
| if (is_strided()) { | ||
| RETURN_NOT_OK(CopyStridedArray<Date32Type>(arr_, length_, pool_, data)); | ||
| } else { | ||
| // Can zero-copy | ||
| *data = std::make_shared<NumPyBuffer>(reinterpret_cast<PyObject*>(arr_)); | ||
| } | ||
|
|
||
| std::shared_ptr<DataType> input_type; | ||
|
|
||
| RETURN_NOT_OK(PrepareInputData<Date32Type>(data)); | ||
|
||
|
|
||
| auto date_dtype = reinterpret_cast<PyArray_DatetimeDTypeMetaData*>(dtype_->c_metadata); | ||
| if (dtype_->type_num == NPY_DATETIME) { | ||
| // If we have inbound datetime64[D] data, this needs to be downcasted | ||
|
|
@@ -489,17 +487,11 @@ inline Status NumPyConverter::ConvertData<Date32Type>(std::shared_ptr<Buffer>* d | |
|
|
||
| template <> | ||
| inline Status NumPyConverter::ConvertData<Date64Type>(std::shared_ptr<Buffer>* data) { | ||
| if (is_strided()) { | ||
| RETURN_NOT_OK(CopyStridedArray<Date64Type>(arr_, length_, pool_, data)); | ||
| } else { | ||
| // Can zero-copy | ||
| *data = std::make_shared<NumPyBuffer>(reinterpret_cast<PyObject*>(arr_)); | ||
| } | ||
|
|
||
| constexpr int64_t kMillisecondsInDay = 86400000; | ||
|
|
||
| std::shared_ptr<DataType> input_type; | ||
|
|
||
| RETURN_NOT_OK(PrepareInputData<Date64Type>(data)); | ||
|
|
||
| auto date_dtype = reinterpret_cast<PyArray_DatetimeDTypeMetaData*>(dtype_->c_metadata); | ||
| if (dtype_->type_num == NPY_DATETIME) { | ||
| // If we have inbound datetime64[D] data, this needs to be downcasted | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -113,13 +113,13 @@ def _check_array_roundtrip(values, expected=None, mask=None, | |
| else: | ||
| assert arr.null_count == (mask | values_nulls).sum() | ||
|
|
||
| if mask is None: | ||
| tm.assert_series_equal(pd.Series(result), pd.Series(values), | ||
| check_names=False) | ||
| else: | ||
| expected = pd.Series(np.ma.masked_array(values, mask=mask)) | ||
| tm.assert_series_equal(pd.Series(result), expected, | ||
| check_names=False) | ||
| if expected is None: | ||
| if mask is None: | ||
| expected = pd.Series(values) | ||
| else: | ||
| expected = pd.Series(np.ma.masked_array(values, mask=mask)) | ||
|
|
||
| tm.assert_series_equal(pd.Series(result), expected, check_names=False) | ||
|
|
||
|
|
||
| def _check_array_from_pandas_roundtrip(np_array, type=None): | ||
|
|
@@ -559,6 +559,11 @@ def test_float_nulls_to_ints(self): | |
| assert table[0].to_pylist() == [1, 2, None] | ||
| tm.assert_frame_equal(df, table.to_pandas()) | ||
|
|
||
| def test_float_nulls_to_boolean(self): | ||
| s = pd.Series([0.0, 1.0, 2.0, None, -3.0]) | ||
| expected = pd.Series([False, True, True, None, True]) | ||
| _check_array_roundtrip(s, expected=expected, type=pa.bool_()) | ||
|
|
||
| def test_integer_no_nulls(self): | ||
| data = OrderedDict() | ||
| fields = [] | ||
|
|
@@ -672,6 +677,26 @@ def test_boolean_nulls(self): | |
|
|
||
| tm.assert_frame_equal(result, ex_frame) | ||
|
|
||
| def test_boolean_to_int(self): | ||
| # test from dtype=bool | ||
| s = pd.Series([True, True, False, True, True] * 2) | ||
| expected = pd.Series([1, 1, 0, 1, 1] * 2) | ||
| _check_array_roundtrip(s, expected=expected, type=pa.int64()) | ||
|
|
||
| def test_boolean_objects_to_int(self): | ||
| # test from dtype=object | ||
| s = pd.Series([True, True, False, True, True] * 2, dtype=object) | ||
| expected = pd.Series([1, 1, 0, 1, 1] * 2) | ||
| expected_msg = 'Expected integer, got bool' | ||
| with pytest.raises(pa.ArrowTypeError, match=expected_msg): | ||
|
||
| _check_array_roundtrip(s, expected=expected, type=pa.int64()) | ||
|
|
||
| def test_boolean_nulls_to_float(self): | ||
| # test from dtype=object | ||
| s = pd.Series([True, True, False, None, True] * 2) | ||
| expected = pd.Series([1.0, 1.0, 0.0, None, 1.0] * 2) | ||
| _check_array_roundtrip(s, expected=expected, type=pa.float64()) | ||
|
|
||
| def test_float_object_nulls(self): | ||
| arr = np.array([None, 1.5, np.float64(3.5)] * 5, dtype=object) | ||
| df = pd.DataFrame({'floats': arr}) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment or docstring here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done