Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 2 additions & 18 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,23 +147,6 @@ SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys)
: FunctionOptions(internal::kSelectKOptionsType),
k(k),
sort_keys(std::move(sort_keys)) {}

/// \brief Report whether these options describe a "top-k" selection.
///
/// \return true when every entry of `sort_keys` uses SortOrder::Descending
/// (which includes the case of an empty `sort_keys`), false otherwise.
bool SelectKOptions::is_top_k() const {
  // Walk forward past every descending key; stopping early means a
  // non-descending key was found.
  auto key_it = sort_keys.begin();
  while (key_it != sort_keys.end() && key_it->order == SortOrder::Descending) {
    ++key_it;
  }
  return key_it == sort_keys.end();
}
/// \brief Report whether these options describe a "bottom-k" selection.
///
/// \return true when every entry of `sort_keys` uses SortOrder::Ascending
/// (which includes the case of an empty `sort_keys`), false otherwise.
bool SelectKOptions::is_bottom_k() const {
  // Walk forward past every ascending key; stopping early means a
  // non-ascending key was found.
  auto key_it = sort_keys.begin();
  while (key_it != sort_keys.end() && key_it->order == SortOrder::Ascending) {
    ++key_it;
  }
  return key_it == sort_keys.end();
}
constexpr char SelectKOptions::kTypeName[];

namespace internal {
Expand All @@ -189,7 +172,8 @@ Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n,
return result.make_array();
}

Result<std::shared_ptr<Array>> SelectKUnstable(const Datum& datum, SelectKOptions options,
Result<std::shared_ptr<Array>> SelectKUnstable(const Datum& datum,
const SelectKOptions& options,
ExecContext* ctx) {
ARROW_ASSIGN_OR_RAISE(Datum result,
CallFunction("select_k_unstable", {datum}, &options, ctx));
Expand Down
18 changes: 9 additions & 9 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,6 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions {
}
return SelectKOptions{k, keys};
}
bool is_top_k() const;

bool is_bottom_k() const;

/// The number of `k` elements to keep.
int64_t k;
Expand Down Expand Up @@ -292,19 +289,22 @@ ARROW_EXPORT
Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n,
ExecContext* ctx = NULLPTR);

/// \brief Returns the first k elements ordered by `options.keys`.
/// \brief Returns the indices that would select the first `k` elements of the array in
/// the specified order.
///
/// Return a sorted array with its elements rearranged in such
/// a way that the value of the element in k-th position (options.k) is in the position it
/// would be in a sorted datum ordered by `options.keys`. Null like values will be not
/// part of the output. Output is not guaranteed to be stable.
/// Perform an indirect sort of the datum, keeping only the first `k` elements. The output
/// array will contain indices such that the item indicated by the k-th index will be in
/// the position it would be if the datum were sorted by `options.sort_keys`. However,
/// indices of null values will not be part of the output. The sort is not guaranteed to
/// be stable.
///
/// \param[in] datum datum to be partitioned
/// \param[in] options options
/// \param[in] ctx the function execution context, optional
/// \return an array of indices selecting the first `k` elements in the requested order
ARROW_EXPORT
Result<std::shared_ptr<Array>> SelectKUnstable(const Datum& datum, SelectKOptions options,
Result<std::shared_ptr<Array>> SelectKUnstable(const Datum& datum,
const SelectKOptions& options,
ExecContext* ctx = NULLPTR);

/// \brief Returns the indices that would sort an array in the
Expand Down
52 changes: 16 additions & 36 deletions cpp/src/arrow/compute/kernels/select_k_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ using internal::checked_pointer_cast;
namespace compute {

template <typename ArrayType, SortOrder order>
class SelectKComparator {
class SelectKCompareForResult {
public:
template <typename Type>
bool operator()(const Type& lval, const Type& rval) {
Expand All @@ -63,16 +63,6 @@ Result<std::shared_ptr<Array>> SelectK(const Datum& values, int64_t k) {
}
}

/// \brief Test helper that forwards to SelectKUnstable.
///
/// The `order` template parameter is kept so that existing
/// `SelectK<order>(...)` call sites continue to compile. It has no effect on
/// the call: the previous implementation branched on `order` but issued the
/// exact same SelectKUnstable call in both branches.
///
/// \param[in] values datum to select from
/// \param[in] options select-k options forwarded unchanged
/// \return whatever SelectKUnstable returns for `values` and `options`
template <SortOrder order>
Result<std::shared_ptr<Array>> SelectK(const Datum& values,
                                       const SelectKOptions& options) {
  // `values` is already a Datum, so forward it directly instead of copying.
  return SelectKUnstable(values, options);
}

void ValidateSelectK(const Datum& datum, Array& select_k_indices, SortOrder order,
bool stable_sort = false) {
ASSERT_TRUE(datum.is_arraylike());
Expand Down Expand Up @@ -298,7 +288,7 @@ template <typename ArrayType, SortOrder order>
void ValidateSelectKIndices(const ArrayType& array) {
ValidateOutput(array);

SelectKComparator<ArrayType, order> compare;
SelectKCompareForResult<ArrayType, order> compare;
for (uint64_t i = 1; i < static_cast<uint64_t>(array.length()); i++) {
using ArrowType = typename ArrayType::TypeClass;
using GetView = internal::GetViewType<ArrowType>;
Expand Down Expand Up @@ -365,7 +355,6 @@ TYPED_TEST_SUITE(TestBottomKChunkedArrayRandom, SelectKableTypes);
TYPED_TEST(TestBottomKChunkedArrayRandom, BottomK) { this->TestSelectK(1000); }

// Test basic cases for record batch.
template <SortOrder order>
class TestSelectKWithRecordBatch : public ::testing::Test {
public:
void Check(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
Expand All @@ -378,7 +367,7 @@ class TestSelectKWithRecordBatch : public ::testing::Test {
Status DoSelectK(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
const SelectKOptions& options, std::shared_ptr<RecordBatch>* out) {
auto batch = RecordBatchFromJSON(schm, batch_json);
ARROW_ASSIGN_OR_RAISE(auto indices, SelectK<order>(Datum(*batch), options));
ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*batch), options));

ValidateOutput(*indices);
ARROW_ASSIGN_OR_RAISE(
Expand All @@ -388,9 +377,7 @@ class TestSelectKWithRecordBatch : public ::testing::Test {
}
};

struct TestTopKWithRecordBatch : TestSelectKWithRecordBatch<SortOrder::Descending> {};

TEST_F(TestTopKWithRecordBatch, NoNull) {
TEST_F(TestSelectKWithRecordBatch, TopKNoNull) {
auto schema = ::arrow::schema({
{field("a", uint8())},
{field("b", uint32())},
Expand All @@ -417,7 +404,7 @@ TEST_F(TestTopKWithRecordBatch, NoNull) {
Check(schema, batch_input, options, expected_batch);
}

TEST_F(TestTopKWithRecordBatch, Null) {
TEST_F(TestSelectKWithRecordBatch, TopKNull) {
auto schema = ::arrow::schema({
{field("a", uint8())},
{field("b", uint32())},
Expand All @@ -444,7 +431,7 @@ TEST_F(TestTopKWithRecordBatch, Null) {
Check(schema, batch_input, options, expected_batch);
}

TEST_F(TestTopKWithRecordBatch, OneColumnKey) {
TEST_F(TestSelectKWithRecordBatch, TopKOneColumnKey) {
auto schema = ::arrow::schema({
{field("country", utf8())},
{field("population", uint64())},
Expand Down Expand Up @@ -473,7 +460,7 @@ TEST_F(TestTopKWithRecordBatch, OneColumnKey) {
this->Check(schema, batch_input, options, expected_batch);
}

TEST_F(TestTopKWithRecordBatch, MultipleColumnKeys) {
TEST_F(TestSelectKWithRecordBatch, TopKMultipleColumnKeys) {
auto schema = ::arrow::schema({{field("country", utf8())},
{field("population", uint64())},
{field("GDP", uint64())}});
Expand All @@ -499,9 +486,7 @@ TEST_F(TestTopKWithRecordBatch, MultipleColumnKeys) {
this->Check(schema, batch_input, options, expected_batch);
}

struct TestBottomKWithRecordBatch : TestSelectKWithRecordBatch<SortOrder::Ascending> {};

TEST_F(TestBottomKWithRecordBatch, NoNull) {
TEST_F(TestSelectKWithRecordBatch, BottomKNoNull) {
auto schema = ::arrow::schema({
{field("a", uint8())},
{field("b", uint32())},
Expand All @@ -528,7 +513,7 @@ TEST_F(TestBottomKWithRecordBatch, NoNull) {
Check(schema, batch_input, options, expected_batch);
}

TEST_F(TestBottomKWithRecordBatch, Null) {
TEST_F(TestSelectKWithRecordBatch, BottomKNull) {
auto schema = ::arrow::schema({
{field("a", uint8())},
{field("b", uint32())},
Expand All @@ -555,7 +540,7 @@ TEST_F(TestBottomKWithRecordBatch, Null) {
Check(schema, batch_input, options, expected_batch);
}

TEST_F(TestBottomKWithRecordBatch, OneColumnKey) {
TEST_F(TestSelectKWithRecordBatch, BottomKOneColumnKey) {
auto schema = ::arrow::schema({
{field("country", utf8())},
{field("population", uint64())},
Expand Down Expand Up @@ -584,7 +569,7 @@ TEST_F(TestBottomKWithRecordBatch, OneColumnKey) {
this->Check(schema, batch_input, options, expected_batch);
}

TEST_F(TestBottomKWithRecordBatch, MultipleColumnKeys) {
TEST_F(TestSelectKWithRecordBatch, BottomKMultipleColumnKeys) {
auto schema = ::arrow::schema({{field("country", utf8())},
{field("population", uint64())},
{field("GDP", uint64())}});
Expand Down Expand Up @@ -612,7 +597,6 @@ TEST_F(TestBottomKWithRecordBatch, MultipleColumnKeys) {
}

// Test basic cases for table.
template <SortOrder order>
struct TestSelectKWithTable : public ::testing::Test {
void Check(const std::shared_ptr<Schema>& schm,
const std::vector<std::string>& input_json, const SelectKOptions& options,
Expand All @@ -626,7 +610,7 @@ struct TestSelectKWithTable : public ::testing::Test {
const std::vector<std::string>& input_json,
const SelectKOptions& options, std::shared_ptr<Table>* out) {
auto table = TableFromJSON(schm, input_json);
ARROW_ASSIGN_OR_RAISE(auto indices, SelectK<order>(Datum(*table), options));
ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*table), options));
ValidateOutput(*indices);

ARROW_ASSIGN_OR_RAISE(
Expand All @@ -636,9 +620,7 @@ struct TestSelectKWithTable : public ::testing::Test {
}
};

struct TestTopKWithTable : TestSelectKWithTable<SortOrder::Descending> {};

TEST_F(TestTopKWithTable, OneColumnKey) {
TEST_F(TestSelectKWithTable, TopKOneColumnKey) {
auto schema = ::arrow::schema({
{field("a", uint8())},
{field("b", uint32())},
Expand All @@ -661,7 +643,7 @@ TEST_F(TestTopKWithTable, OneColumnKey) {
Check(schema, input, options, expected);
}

TEST_F(TestTopKWithTable, MultipleColumnKeys) {
TEST_F(TestSelectKWithTable, TopKMultipleColumnKeys) {
auto schema = ::arrow::schema({
{field("a", uint8())},
{field("b", uint32())},
Expand All @@ -684,9 +666,7 @@ TEST_F(TestTopKWithTable, MultipleColumnKeys) {
Check(schema, input, options, expected);
}

struct TestBottomKWithTable : TestSelectKWithTable<SortOrder::Ascending> {};

TEST_F(TestBottomKWithTable, OneColumnKey) {
TEST_F(TestSelectKWithTable, BottomKOneColumnKey) {
auto schema = ::arrow::schema({
{field("a", uint8())},
{field("b", uint32())},
Expand All @@ -709,7 +689,7 @@ TEST_F(TestBottomKWithTable, OneColumnKey) {
Check(schema, input, options, expected);
}

TEST_F(TestBottomKWithTable, MultipleColumnKeys) {
TEST_F(TestSelectKWithTable, BottomKMultipleColumnKeys) {
auto schema = ::arrow::schema({
{field("a", uint8())},
{field("b", uint32())},
Expand Down
Loading