Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ static auto kSortOptionsType =
GetFunctionOptionsType<SortOptions>(DataMember("sort_keys", &SortOptions::sort_keys));
static auto kPartitionNthOptionsType = GetFunctionOptionsType<PartitionNthOptions>(
DataMember("pivot", &PartitionNthOptions::pivot));
// Options-type descriptor for SelectKOptions, built from its `k` and `sort_keys`
// data members. The SelectKOptions constructor hands it to the FunctionOptions base
// and RegisterVectorOptions() adds it to the function registry.
static auto kSelectKOptionsType = GetFunctionOptionsType<SelectKOptions>(
DataMember("k", &SelectKOptions::k),
DataMember("sort_keys", &SelectKOptions::sort_keys));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -140,6 +143,29 @@ PartitionNthOptions::PartitionNthOptions(int64_t pivot)
: FunctionOptions(internal::kPartitionNthOptionsType), pivot(pivot) {}
constexpr char PartitionNthOptions::kTypeName[];

// Builds the options from the number of elements to keep (`k`) and the sort keys
// used to rank them. `sort_keys` is taken by value and moved into the member, so a
// caller passing an rvalue pays for no extra copy of the vector.
SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys)
: FunctionOptions(internal::kSelectKOptionsType),
k(k),
sort_keys(std::move(sort_keys)) {}

// Reports whether every sort key requests descending order, i.e. whether these
// options describe a "top k" selection. An empty key list yields true.
bool SelectKOptions::is_top_k() const {
  bool all_descending = true;
  // Stop scanning as soon as one key with a different order is found.
  for (auto key = sort_keys.begin(); all_descending && key != sort_keys.end(); ++key) {
    all_descending = (key->order == SortOrder::Descending);
  }
  return all_descending;
}
// Reports whether every sort key requests ascending order, i.e. whether these
// options describe a "bottom k" selection. An empty key list yields true.
bool SelectKOptions::is_bottom_k() const {
  bool all_ascending = true;
  // Stop scanning as soon as one key with a different order is found.
  for (auto key = sort_keys.begin(); all_ascending && key != sort_keys.end(); ++key) {
    all_ascending = (key->order == SortOrder::Ascending);
  }
  return all_ascending;
}
// Out-of-class definition of the `constexpr static` array declared inside the class;
// required before C++17 whenever the member is ODR-used.
constexpr char SelectKOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -148,6 +174,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kArraySortOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType));
}
} // namespace internal

Expand All @@ -162,6 +189,13 @@ Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n,
return result.make_array();
}

// Convenience wrapper: invokes the "select_k_unstable" compute function on `datum`
// with `options` and converts the resulting Datum into an Array.
// `options` is a by-value copy, so its address stays valid for the whole call.
// A failed CallFunction result is propagated to the caller by ARROW_ASSIGN_OR_RAISE.
Result<std::shared_ptr<Array>> SelectKUnstable(const Datum& datum, SelectKOptions options,
ExecContext* ctx) {
ARROW_ASSIGN_OR_RAISE(Datum result,
CallFunction("select_k_unstable", {datum}, &options, ctx));
return result.make_array();
}

Result<Datum> ReplaceWithMask(const Datum& values, const Datum& mask,
const Datum& replacements, ExecContext* ctx) {
return CallFunction("replace_with_mask", {values, mask, replacements}, ctx);
Expand Down
55 changes: 55 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,46 @@ class ARROW_EXPORT SortOptions : public FunctionOptions {
std::vector<SortKey> sort_keys;
};

/// \brief SelectK options
///
/// Holds how many elements to keep (`k`) and the key(s) and order(s) used to rank them.
class ARROW_EXPORT SelectKOptions : public FunctionOptions {
 public:
  /// \param[in] k number of elements to keep (defaults to -1)
  /// \param[in] sort_keys column key(s) to order by and how to order by them
  explicit SelectKOptions(int64_t k = -1, std::vector<SortKey> sort_keys = {});

  constexpr static char const kTypeName[] = "SelectKOptions";

  /// \brief Options with `k == -1` and no sort keys.
  static SelectKOptions Defaults() { return SelectKOptions{-1, {}}; }

  /// \brief Options keeping `k` elements with every key in descending order.
  ///
  /// One sort key is created per entry of `key_names`. When `key_names` is empty, a
  /// single placeholder key named "not-used" is added so that the requested order is
  /// still recorded.
  static SelectKOptions TopKDefault(int64_t k, std::vector<std::string> key_names = {}) {
    return WithUniformOrder(k, key_names, SortOrder::Descending);
  }

  /// \brief Options keeping `k` elements with every key in ascending order.
  ///
  /// One sort key is created per entry of `key_names`. When `key_names` is empty, a
  /// single placeholder key named "not-used" is added so that the requested order is
  /// still recorded.
  static SelectKOptions BottomKDefault(int64_t k,
                                       std::vector<std::string> key_names = {}) {
    return WithUniformOrder(k, key_names, SortOrder::Ascending);
  }

  /// \brief True when every sort key is descending (also true when there are no keys).
  bool is_top_k() const;

  /// \brief True when every sort key is ascending (also true when there are no keys).
  bool is_bottom_k() const;

  /// The number of `k` elements to keep.
  int64_t k;
  /// Column key(s) to order by and how to order by these sort keys.
  std::vector<SortKey> sort_keys;

 private:
  /// Shared implementation of TopKDefault() / BottomKDefault(): builds options whose
  /// sort keys all use `order`. The keys are emplaced straight into the result's
  /// `sort_keys`, so no intermediate vector is built and then copied.
  static SelectKOptions WithUniformOrder(int64_t k,
                                         const std::vector<std::string>& key_names,
                                         SortOrder order) {
    SelectKOptions options(k);
    if (key_names.empty()) {
      // Placeholder key; presumably lets is_top_k() / is_bottom_k() report the order
      // even when no column names apply -- confirm against the kernel.
      options.sort_keys.emplace_back("not-used", order);
    } else {
      options.sort_keys.reserve(key_names.size());
      for (const auto& name : key_names) {
        options.sort_keys.emplace_back(name, order);
      }
    }
    return options;
  }
};

/// \brief Partitioning options for NthToIndices
class ARROW_EXPORT PartitionNthOptions : public FunctionOptions {
public:
Expand Down Expand Up @@ -252,6 +292,21 @@ ARROW_EXPORT
Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n,
ExecContext* ctx = NULLPTR);

/// \brief Returns the first k elements ordered by `options.sort_keys`.
///
/// Return a sorted array with its elements rearranged in such
/// a way that the value of the element in k-th position (options.k) is in the position it
/// would be in a sorted datum ordered by `options.sort_keys`. Null-like values will not
/// be part of the output. Output is not guaranteed to be stable.
///
/// \param[in] datum datum to be partitioned
/// \param[in] options selection options: `options.k` is the number of elements to keep,
///            `options.sort_keys` the key(s) and order(s) to rank by
/// \param[in] ctx the function execution context, optional
/// \return an array holding the selected elements
ARROW_EXPORT
Result<std::shared_ptr<Array>> SelectKUnstable(const Datum& datum, SelectKOptions options,
ExecContext* ctx = NULLPTR);

/// \brief Returns the indices that would sort an array in the
/// specified order.
///
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ TEST(FunctionOptions, Equality) {
{SortKey("key", SortOrder::Descending), SortKey("value", SortOrder::Descending)}));
options.emplace_back(new PartitionNthOptions(/*pivot=*/0));
options.emplace_back(new PartitionNthOptions(/*pivot=*/42));
options.emplace_back(new SelectKOptions(0, {}));
options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}}));

for (size_t i = 0; i < options.size(); i++) {
const size_t prev_i = i == 0 ? options.size() - 1 : i - 1;
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ add_arrow_compute_test(vector_test
vector_replace_test.cc
vector_selection_test.cc
vector_sort_test.cc
select_k_test.cc
test_util.cc)

add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_topk_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_replace_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_selection_benchmark PREFIX "arrow-compute")

Expand Down
Loading