Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cpp/src/arrow/compute/kernels/vector_sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2337,9 +2337,12 @@ class SelectKUnstableMetaFunction : public MetaFunction {
const FunctionOptions* options, ExecContext* ctx) const {
const SelectKOptions& select_k_options = static_cast<const SelectKOptions&>(*options);
if (select_k_options.k < 0) {
return Status::Invalid("SelectK requires a nonnegative `k`, got ",
return Status::Invalid("select_k_unstable requires a nonnegative `k`, got ",
select_k_options.k);
}
if (select_k_options.sort_keys.size() == 0) {
return Status::Invalid("select_k_unstable requires a non-empty `sort_keys`");
}
switch (args[0].kind()) {
case Datum::ARRAY: {
return SelectKth(*args[0].make_array(), select_k_options, ctx);
Expand Down
82 changes: 82 additions & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,3 +643,85 @@ def fill_null(values, fill_value):
fill_value = pa.scalar(fill_value.as_py(), type=values.type)

return call_function("coalesce", [values, fill_value])


def top_k_unstable(values, k, sort_keys=None, memory_pool=None):
    """
    Select the indices of the top-k ordered elements from array- or table-like
    data.

    This is a specialization for :func:`select_k_unstable`. Output is not
    guaranteed to be stable.

    Parameters
    ----------
    values : Array, ChunkedArray, RecordBatch, or Table
        Data to select the indices from.
    k : int
        The number of `k` elements to keep.
    sort_keys : list of str, optional
        Column key names to order by when input is table-like data.
        The passed list is never modified.
    memory_pool : optional
        Forwarded unchanged to ``call_function``.

    Returns
    -------
    result : Array of indices

    Examples
    --------
    >>> import pyarrow as pa
    >>> import pyarrow.compute as pc
    >>> arr = pa.array(["a", "b", "c", None, "e", "f"])
    >>> pc.top_k_unstable(arr, k=3)
    <pyarrow.lib.UInt64Array object at 0x7fdcb19d7f30>
    [
      5,
      4,
      2
    ]
    """
    if isinstance(values, (pa.Array, pa.ChunkedArray)):
        # Array-like input has no column names; a placeholder key name is
        # appended so that the "descending" order reaches the kernel.
        # Work on a copy so the caller's list is not mutated across calls.
        keys = [] if sort_keys is None else list(sort_keys)
        keys.append(("dummy", "descending"))
    else:
        # Build a concrete list (not a one-shot iterator) of
        # (column name, order) pairs.
        names = [] if sort_keys is None else sort_keys
        keys = [(key_name, "descending") for key_name in names]
    options = SelectKOptions(k=k, sort_keys=keys)
    return call_function("select_k_unstable", [values], options, memory_pool)


def bottom_k_unstable(values, k, sort_keys=None, memory_pool=None):
    """
    Select the indices of the bottom-k ordered elements from
    array- or table-like data.

    This is a specialization for :func:`select_k_unstable`. Output is not
    guaranteed to be stable.

    Parameters
    ----------
    values : Array, ChunkedArray, RecordBatch, or Table
        Data to select the indices from.
    k : int
        The number of `k` elements to keep.
    sort_keys : list of str, optional
        Column key names to order by when input is table-like data.
        The passed list is never modified.
    memory_pool : optional
        Forwarded unchanged to ``call_function``.

    Returns
    -------
    result : Array of indices

    Examples
    --------
    >>> import pyarrow as pa
    >>> import pyarrow.compute as pc
    >>> arr = pa.array(["a", "b", "c", None, "e", "f"])
    >>> pc.bottom_k_unstable(arr, k=3)
    <pyarrow.lib.UInt64Array object at 0x7fdcb19d7fa0>
    [
      0,
      1,
      2
    ]
    """
    if isinstance(values, (pa.Array, pa.ChunkedArray)):
        # Array-like input has no column names; a placeholder key name is
        # appended so that the "ascending" order reaches the kernel.
        # Work on a copy so the caller's list is not mutated across calls.
        keys = [] if sort_keys is None else list(sort_keys)
        keys.append(("dummy", "ascending"))
    else:
        # Build a concrete list (not a one-shot iterator) of
        # (column name, order) pairs.
        names = [] if sort_keys is None else sort_keys
        keys = [(key_name, "ascending") for key_name in names]
    options = SelectKOptions(k=k, sort_keys=keys)
    return call_function("select_k_unstable", [values], options, memory_pool)
52 changes: 36 additions & 16 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,35 +1859,42 @@ def validate_select_k(select_k_indices, arr, order, stable_sort=False):
assert actual == expected

arr = pa.array([1, 2, None, 0])
for order in ["descending", "ascending"]:
for k in [0, 2, 4]:
for k in [0, 2, 4]:
for order in ["descending", "ascending"]:
result = pc.select_k_unstable(
arr, k=k, sort_keys=[("dummy", order)])
validate_select_k(result, arr, order)

result = pc.select_k_unstable(arr, options=pc.SelectKOptions(
k=2, sort_keys=[("dummy", "descending")]))
result = pc.top_k_unstable(arr, k=k)
validate_select_k(result, arr, "descending")

result = pc.bottom_k_unstable(arr, k=k)
validate_select_k(result, arr, "ascending")

result = pc.select_k_unstable(
arr, options=pc.SelectKOptions(
k=2, sort_keys=[("dummy", "descending")])
)
validate_select_k(result, arr, "descending")

result = pc.select_k_unstable(arr, options=pc.SelectKOptions(
k=2, sort_keys=[("dummy", "ascending")]))
result = pc.select_k_unstable(
arr, options=pc.SelectKOptions(k=2, sort_keys=[("dummy", "ascending")])
)
validate_select_k(result, arr, "ascending")


def test_select_k_table():
table = pa.table({"a": [1, 2, 0], "b": [1, 0, 1]})

def validate_select_k(select_k_indices, table, sort_keys,
stable_sort=False):
sorted_indices = pc.sort_indices(table, sort_keys=sort_keys)
def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False):
sorted_indices = pc.sort_indices(tbl, sort_keys=sort_keys)
head_k_indices = sorted_indices.slice(0, len(select_k_indices))
if stable_sort:
assert select_k_indices == head_k_indices
else:
expected = pc.take(table, head_k_indices)
actual = pc.take(table, select_k_indices)
expected = pc.take(tbl, head_k_indices)
actual = pc.take(tbl, select_k_indices)
assert actual == expected

table = pa.table({"a": [1, 2, 0], "b": [1, 0, 1]})
for k in [0, 2, 4]:
result = pc.select_k_unstable(
table, k=k, sort_keys=[("a", "ascending")])
Expand All @@ -1896,13 +1903,26 @@ def validate_select_k(select_k_indices, table, sort_keys,
result = pc.select_k_unstable(
table, k=k, sort_keys=[("a", "ascending"), ("b", "ascending")]
)
validate_select_k(result, table, sort_keys=[
("a", "ascending"), ("b", "ascending")])
validate_select_k(
result, table, sort_keys=[("a", "ascending"), ("b", "ascending")]
)

result = pc.top_k_unstable(table, k=k, sort_keys=["a"])
validate_select_k(result, table, sort_keys=[("a", "descending")])

result = pc.bottom_k_unstable(table, k=k, sort_keys=["a", "b"])
validate_select_k(
result, table, sort_keys=[("a", "ascending"), ("b", "ascending")]
)

with pytest.raises(ValueError,
match="SelectK requires a nonnegative `k`"):
match="select_k_unstable requires a nonnegative `k`"):
pc.select_k_unstable(table)

with pytest.raises(ValueError, match="select_k_unstable requires "
"a non-empty `sort_keys`"):
pc.select_k_unstable(table, k=2)

with pytest.raises(ValueError, match="not a valid order"):
pc.select_k_unstable(table, k=k, sort_keys=[("a", "nonscending")])

Expand Down