From 4a4ba209a9e6241716e99746dcb6df26a2c72293 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 16 Sep 2021 12:28:53 -0500 Subject: [PATCH] minor fix minor fixes --- cpp/src/arrow/compute/kernels/vector_sort.cc | 5 +- python/pyarrow/compute.py | 82 ++++++++++++++++++++ python/pyarrow/tests/test_compute.py | 52 +++++++++---- 3 files changed, 122 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index da6d9cabac8..a2bbc30ed42 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -2337,9 +2337,12 @@ class SelectKUnstableMetaFunction : public MetaFunction { const FunctionOptions* options, ExecContext* ctx) const { const SelectKOptions& select_k_options = static_cast(*options); if (select_k_options.k < 0) { - return Status::Invalid("SelectK requires a nonnegative `k`, got ", + return Status::Invalid("select_k_unstable requires a nonnegative `k`, got ", select_k_options.k); } + if (select_k_options.sort_keys.size() == 0) { + return Status::Invalid("select_k_unstable requires a non-empty `sort_keys`"); + } switch (args[0].kind()) { case Datum::ARRAY: { return SelectKth(*args[0].make_array(), select_k_options, ctx); diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index d0112f26130..4155ec1ad38 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -643,3 +643,85 @@ def fill_null(values, fill_value): fill_value = pa.scalar(fill_value.as_py(), type=values.type) return call_function("coalesce", [values, fill_value]) + + +def top_k_unstable(values, k, sort_keys=None, memory_pool=None): + """ + Select the indices of the top-k ordered elements from array- or table-like + data. + + This is a specialization for :func:`select_k_unstable`. Output is not + guaranteed to be stable. + + Parameters + ---------- + values : Array, ChunkedArray, RecordBatch, or Table + k : The number of `k` elements to keep. + sort_keys : Column key names to order by when input is table-like data. + + Returns + ------- + result : Array of indices + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> arr = pa.array(["a", "b", "c", None, "e", "f"]) + >>> pc.top_k_unstable(arr, k=3) + + [ + 5, + 4, + 2 + ] + """ + if sort_keys is None: + sort_keys = [] + if isinstance(values, (pa.Array, pa.ChunkedArray)): + sort_keys.append(("dummy", "descending")) + else: + sort_keys = map(lambda key_name: (key_name, "descending"), sort_keys) + options = SelectKOptions(k=k, sort_keys=sort_keys) + return call_function("select_k_unstable", [values], options, memory_pool) + + +def bottom_k_unstable(values, k, sort_keys=None, memory_pool=None): + """ + Select the indices of the bottom-k ordered elements from + array- or table-like data. + + This is a specialization for :func:`select_k_unstable`. Output is not + guaranteed to be stable. + + Parameters + ---------- + values : Array, ChunkedArray, RecordBatch, or Table + k : The number of `k` elements to keep. + sort_keys : Column key names to order by when input is table-like data. + + Returns + ------- + result : Array of indices + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> arr = pa.array(["a", "b", "c", None, "e", "f"]) + >>> pc.bottom_k_unstable(arr, k=3) + + [ + 0, + 1, + 2 + ] + """ + if sort_keys is None: + sort_keys = [] + if isinstance(values, (pa.Array, pa.ChunkedArray)): + sort_keys.append(("dummy", "ascending")) + else: + sort_keys = map(lambda key_name: (key_name, "ascending"), sort_keys) + options = SelectKOptions(k=k, sort_keys=sort_keys) + return call_function("select_k_unstable", [values], options, memory_pool) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 46ba78027b5..b33e482df89 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -1859,35 +1859,42 @@ def validate_select_k(select_k_indices, arr, order, stable_sort=False): assert actual == expected arr = pa.array([1, 2, None, 0]) - for order in ["descending", "ascending"]: - for k in [0, 2, 4]: + for k in [0, 2, 4]: + for order in ["descending", "ascending"]: result = pc.select_k_unstable( arr, k=k, sort_keys=[("dummy", order)]) validate_select_k(result, arr, order) - result = pc.select_k_unstable(arr, options=pc.SelectKOptions( - k=2, sort_keys=[("dummy", "descending")])) + result = pc.top_k_unstable(arr, k=k) + validate_select_k(result, arr, "descending") + + result = pc.bottom_k_unstable(arr, k=k) + validate_select_k(result, arr, "ascending") + + result = pc.select_k_unstable( + arr, options=pc.SelectKOptions( + k=2, sort_keys=[("dummy", "descending")]) + ) validate_select_k(result, arr, "descending") - result = pc.select_k_unstable(arr, options=pc.SelectKOptions( - k=2, sort_keys=[("dummy", "ascending")])) + result = pc.select_k_unstable( + arr, options=pc.SelectKOptions(k=2, sort_keys=[("dummy", "ascending")]) + ) validate_select_k(result, arr, "ascending") def test_select_k_table(): - table = pa.table({"a": [1, 2, 0], "b": [1, 0, 1]}) - - def validate_select_k(select_k_indices, table, sort_keys, - stable_sort=False): - sorted_indices = pc.sort_indices(table, sort_keys=sort_keys) + def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): + sorted_indices = pc.sort_indices(tbl, sort_keys=sort_keys) head_k_indices = sorted_indices.slice(0, len(select_k_indices)) if stable_sort: assert select_k_indices == head_k_indices else: - expected = pc.take(table, head_k_indices) - actual = pc.take(table, select_k_indices) + expected = pc.take(tbl, head_k_indices) + actual = pc.take(tbl, select_k_indices) assert actual == expected + table = pa.table({"a": [1, 2, 0], "b": [1, 0, 1]}) for k in [0, 2, 4]: result = pc.select_k_unstable( table, k=k, sort_keys=[("a", "ascending")]) @@ -1896,13 +1903,26 @@ def validate_select_k(select_k_indices, table, sort_keys, result = pc.select_k_unstable( table, k=k, sort_keys=[("a", "ascending"), ("b", "ascending")] ) - validate_select_k(result, table, sort_keys=[ - ("a", "ascending"), ("b", "ascending")]) + validate_select_k( + result, table, sort_keys=[("a", "ascending"), ("b", "ascending")] + ) + + result = pc.top_k_unstable(table, k=k, sort_keys=["a"]) + validate_select_k(result, table, sort_keys=[("a", "descending")]) + + result = pc.bottom_k_unstable(table, k=k, sort_keys=["a", "b"]) + validate_select_k( + result, table, sort_keys=[("a", "ascending"), ("b", "ascending")] + ) with pytest.raises(ValueError, - match="SelectK requires a nonnegative `k`"): + match="select_k_unstable requires a nonnegative `k`"): pc.select_k_unstable(table) + with pytest.raises(ValueError, match="select_k_unstable requires " + "a non-empty `sort_keys`"): + pc.select_k_unstable(table, k=2) + with pytest.raises(ValueError, match="not a valid order"): pc.select_k_unstable(table, k=k, sort_keys=[("a", "nonscending")])