diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index a882d3a955469..841275e54e3d6 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -22,8 +22,12 @@ pa_version_under4p0, pa_version_under5p0, pa_version_under6p0, + pa_version_under7p0, +) +from pandas.util._decorators import ( + deprecate_nonkeyword_arguments, + doc, ) -from pandas.util._decorators import doc from pandas.core.dtypes.common import ( is_array_like, @@ -418,6 +422,58 @@ def isna(self) -> npt.NDArray[np.bool_]: else: return self._data.is_null().to_numpy() + @deprecate_nonkeyword_arguments(version=None, allowed_args=["self"]) + def argsort( + self, + ascending: bool = True, + kind: str = "quicksort", + na_position: str = "last", + *args, + **kwargs, + ) -> np.ndarray: + order = "ascending" if ascending else "descending" + null_placement = {"last": "at_end", "first": "at_start"}.get(na_position, None) + if null_placement is None or pa_version_under7p0: + # Although pc.array_sort_indices exists in version 6 + # there's a bug that affects the pa.ChunkedArray backing + # https://issues.apache.org/jira/browse/ARROW-12042 + fallback_performancewarning("7") + return super().argsort( + ascending=ascending, kind=kind, na_position=na_position + ) + + result = pc.array_sort_indices( + self._data, order=order, null_placement=null_placement + ) + if pa_version_under2p0: + np_result = result.to_pandas().values + else: + np_result = result.to_numpy() + return np_result.astype(np.intp, copy=False) + + def _argmin_max(self, skipna: bool, method: str) -> int: + if self._data.length() in (0, self._data.null_count) or ( + self._hasna and not skipna + ): + # For empty or all null, pyarrow returns -1 but pandas expects TypeError + # For skipna=False and data w/ null, pandas expects NotImplementedError + # let ExtensionArray.arg{max|min} raise + return getattr(super(), f"arg{method}")(skipna=skipna) + + if pa_version_under6p0: + raise NotImplementedError( + f"arg{method} only implemented for pyarrow version >= 6.0" + ) + + value = getattr(pc, method)(self._data, skip_nulls=skipna) + return pc.index(self._data, value).as_py() + + def argmin(self, skipna: bool = True) -> int: + return self._argmin_max(skipna, "min") + + def argmax(self, skipna: bool = True) -> int: + return self._argmin_max(skipna, "max") + def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: """ Return a shallow copy of the array. diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index a2a96da02b2a6..2f3482ddc4811 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1385,6 +1385,11 @@ def test_value_counts_with_normalize(self, data, request): ) super().test_value_counts_with_normalize(data) + @pytest.mark.xfail( + pa_version_under6p0, + raises=NotImplementedError, + reason="argmin/max only implemented for pyarrow version >= 6.0", + ) def test_argmin_argmax( self, data_for_sorting, data_missing_for_sorting, na_value, request ): @@ -1395,8 +1400,50 @@ def test_argmin_argmax( reason=f"{pa_dtype} only has 2 unique possible values", ) ) + elif pa.types.is_duration(pa_dtype): + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"min_max not supported in pyarrow for {pa_dtype}", + ) + ) super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value) + @pytest.mark.parametrize( + "op_name, skipna, expected", + [ + ("idxmax", True, 0), + ("idxmin", True, 2), + ("argmax", True, 0), + ("argmin", True, 2), + ("idxmax", False, np.nan), + ("idxmin", False, np.nan), + ("argmax", False, -1), + ("argmin", False, -1), + ], + ) + def test_argreduce_series( + self, data_missing_for_sorting, op_name, skipna, expected, request + ): + pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype + if pa_version_under6p0 and skipna: + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason="min_max not supported in pyarrow", + ) + ) + elif not pa_version_under6p0 and pa.types.is_duration(pa_dtype) and skipna: + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"min_max not supported in pyarrow for {pa_dtype}", + ) + ) + super().test_argreduce_series( + data_missing_for_sorting, op_name, skipna, expected + ) + @pytest.mark.parametrize("ascending", [True, False]) def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request): pa_dtype = data_for_sorting.dtype.pyarrow_dtype diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 6cea21b6672d8..e4293d6d70e38 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -167,7 +167,48 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna): class TestMethods(base.BaseMethodsTests): - pass + def test_argmin_argmax( + self, data_for_sorting, data_missing_for_sorting, na_value, request + ): + if pa_version_under6p0 and data_missing_for_sorting.dtype.storage == "pyarrow": + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason="min_max not supported in pyarrow", + ) + ) + super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value) + + @pytest.mark.parametrize( + "op_name, skipna, expected", + [ + ("idxmax", True, 0), + ("idxmin", True, 2), + ("argmax", True, 0), + ("argmin", True, 2), + ("idxmax", False, np.nan), + ("idxmin", False, np.nan), + ("argmax", False, -1), + ("argmin", False, -1), + ], + ) + def test_argreduce_series( + self, data_missing_for_sorting, op_name, skipna, expected, request + ): + if ( + pa_version_under6p0 + and data_missing_for_sorting.dtype.storage == "pyarrow" + and skipna + ): + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason="min_max not supported in pyarrow", + ) + ) + super().test_argreduce_series( + data_missing_for_sorting, op_name, skipna, expected + ) class TestCasting(base.BaseCastingTests): diff --git a/pandas/tests/indexes/test_common.py b/pandas/tests/indexes/test_common.py index d582a469eaf0e..e7e971f957e48 100644 --- a/pandas/tests/indexes/test_common.py +++ b/pandas/tests/indexes/test_common.py @@ -10,7 +10,7 @@ from pandas.compat import ( IS64, - pa_version_under2p0, + pa_version_under7p0, ) from pandas.core.dtypes.common import is_integer_dtype @@ -396,11 +396,16 @@ def test_astype_preserves_name(self, index, dtype): # imaginary components discarded warn = np.ComplexWarning + is_pyarrow_str = ( + str(index.dtype) == "string[pyarrow]" + and pa_version_under7p0 + and dtype == "category" + ) try: # Some of these conversions cannot succeed so we use a try / except with tm.assert_produces_warning( warn, - raise_on_extra_warnings=not pa_version_under2p0, + raise_on_extra_warnings=is_pyarrow_str, ): result = index.astype(dtype) except (ValueError, TypeError, NotImplementedError, SystemError): diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index f38a6c89e1bcb..45ecd09e550d0 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -8,6 +8,8 @@ import numpy as np import pytest +from pandas.compat import pa_version_under7p0 + from pandas.core.dtypes.cast import find_common_type from pandas import ( @@ -177,7 +179,8 @@ def test_dunder_inplace_setops_deprecated(index): with tm.assert_produces_warning(FutureWarning): index &= index - with tm.assert_produces_warning(FutureWarning): + is_pyarrow = str(index.dtype) == "string[pyarrow]" and pa_version_under7p0 + with tm.assert_produces_warning(FutureWarning, raise_on_extra_warnings=is_pyarrow): index ^= index