Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions cpp/src/arrow/compute/exec/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <vector>

#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/util/optional.h"
Expand Down Expand Up @@ -111,5 +112,19 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
std::function<Future<util::optional<ExecBatch>>()>* generator;
};

/// \brief Make a node which sorts rows passed through it
///
/// All batches pushed to this node will be accumulated, then sorted, by the given
/// fields. Then sorted batches will be forwarded to the generator in sorted order.
class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions {
public:
explicit OrderBySinkNodeOptions(
SortOptions sort_options,
std::function<Future<util::optional<ExecBatch>>()>* generator)
: SinkNodeOptions(generator), sort_options(std::move(sort_options)) {}

SortOptions sort_options;
};

} // namespace compute
} // namespace arrow
110 changes: 110 additions & 0 deletions cpp/src/arrow/compute/exec/plan_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"
#include "arrow/compute/exec/util.h"
#include "arrow/record_batch.h"
#include "arrow/table.h"
#include "arrow/testing/future_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
Expand All @@ -36,6 +38,7 @@
#include "arrow/util/vector.h"

using testing::ElementsAre;
using testing::ElementsAreArray;
using testing::HasSubstr;
using testing::Optional;
using testing::UnorderedElementsAreArray;
Expand Down Expand Up @@ -262,6 +265,7 @@ BatchesWithSchema MakeBasicBatches() {
BatchesWithSchema MakeRandomBatches(const std::shared_ptr<Schema>& schema,
int num_batches = 10, int batch_size = 4) {
BatchesWithSchema out;
out.schema = schema;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


random::RandomArrayGenerator rng(42);
out.batches.resize(num_batches);
Expand Down Expand Up @@ -301,6 +305,36 @@ TEST(ExecPlanExecution, SourceSink) {
}
}

TEST(ExecPlanExecution, SourceOrderBy) {
std::vector<ExecBatch> expected = {
ExecBatchFromJSON({int32(), boolean()},
"[[4, false], [5, null], [6, false], [7, false], [null, true]]")};
for (bool slow : {false, true}) {
SCOPED_TRACE(slow ? "slowed" : "unslowed");

for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel" : "single threaded");

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;

auto basic_data = MakeBasicBatches();

SortOptions options({SortKey("i32", SortOrder::Ascending)});
ASSERT_OK(Declaration::Sequence(
{
{"source", SourceNodeOptions{basic_data.schema,
basic_data.gen(parallel, slow)}},
{"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
})
.AddToPlan(plan.get()));

ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
Finishes(ResultWith(ElementsAreArray(expected))));
}
}
}

TEST(ExecPlanExecution, SourceSinkError) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
Expand Down Expand Up @@ -355,6 +389,43 @@ TEST(ExecPlanExecution, StressSourceSink) {
}
}

TEST(ExecPlanExecution, StressSourceOrderBy) {
auto input_schema = schema({field("a", int32()), field("b", boolean())});
for (bool slow : {false, true}) {
SCOPED_TRACE(slow ? "slowed" : "unslowed");

for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel" : "single threaded");

int num_batches = slow && !parallel ? 30 : 300;

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;

auto random_data = MakeRandomBatches(input_schema, num_batches);

SortOptions options({SortKey("a", SortOrder::Ascending)});
ASSERT_OK(Declaration::Sequence(
{
{"source", SourceNodeOptions{random_data.schema,
random_data.gen(parallel, slow)}},
{"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
})
.AddToPlan(plan.get()));

// Check that data is sorted appropriately
ASSERT_FINISHES_OK_AND_ASSIGN(auto exec_batches,
StartAndCollect(plan.get(), sink_gen));
ASSERT_OK_AND_ASSIGN(auto actual, TableFromExecBatches(input_schema, exec_batches));
ASSERT_OK_AND_ASSIGN(auto original,
TableFromExecBatches(input_schema, random_data.batches));
ASSERT_OK_AND_ASSIGN(auto sort_indices, SortIndices(original, options));
ASSERT_OK_AND_ASSIGN(auto expected, Take(original, sort_indices));
AssertTablesEqual(*actual, *expected.table());
}
}
}

TEST(ExecPlanExecution, StressSourceSinkStopped) {
for (bool slow : {false, true}) {
SCOPED_TRACE(slow ? "slowed" : "unslowed");
Expand Down Expand Up @@ -541,6 +612,45 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumFilter) {
}
}

TEST(ExecPlanExecution, SourceFilterProjectGroupedSumOrderBy) {
for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel/merged" : "serial");

int batch_multiplicity = parallel ? 100 : 1;
auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;

SortOptions options({SortKey("str", SortOrder::Descending)});
ASSERT_OK(
Declaration::Sequence(
{
{"source",
SourceNodeOptions{input.schema, input.gen(parallel, /*slow=*/false)}},
{"filter",
FilterNodeOptions{greater_equal(field_ref("i32"), literal(0))}},
{"project", ProjectNodeOptions{{
field_ref("str"),
call("multiply", {field_ref("i32"), literal(2)}),
}}},
{"aggregate", AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
/*targets=*/{"multiply(i32, 2)"},
/*names=*/{"sum(multiply(i32, 2))"},
/*keys=*/{"str"}}},
{"filter", FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
literal(10 * batch_multiplicity))}},
{"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
})
.AddToPlan(plan.get()));

ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
Finishes(ResultWith(ElementsAreArray({ExecBatchFromJSON(
{int64(), utf8()}, parallel ? R"([[2000, "beta"], [3600, "alfa"]])"
: R"([[20, "beta"], [36, "alfa"]])")}))));
}
}

TEST(ExecPlanExecution, SourceScalarAggSink) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
Expand Down
81 changes: 79 additions & 2 deletions cpp/src/arrow/compute/exec/sink_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@

#include <mutex>

#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/util.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/datum.h"
#include "arrow/result.h"
#include "arrow/table.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/future.h"
Expand Down Expand Up @@ -135,8 +137,8 @@ class SinkNode : public ExecNode {
}
}

private:
void Finish() {
protected:
virtual void Finish() {
if (producer_.Close()) {
finished_.MarkFinished();
}
Expand All @@ -148,7 +150,82 @@ class SinkNode : public ExecNode {
PushGenerator<util::optional<ExecBatch>>::Producer producer_;
};

// A sink node that accumulates inputs, then sorts them before emitting them.
struct OrderBySinkNode final : public SinkNode {
OrderBySinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs, SortOptions sort_options,
AsyncGenerator<util::optional<ExecBatch>>* generator)
: SinkNode(plan, std::move(inputs), generator),
sort_options_(std::move(sort_options)) {}

const char* kind_name() override { return "OrderBySinkNode"; }

static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode"));

const auto& sink_options = checked_cast<const OrderBySinkNodeOptions&>(options);
return plan->EmplaceNode<OrderBySinkNode>(
plan, std::move(inputs), sink_options.sort_options, sink_options.generator);
}

void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
DCHECK_EQ(input, inputs_[0]);

// Accumulate data
{
std::unique_lock<std::mutex> lock(mutex_);
auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(),
plan()->exec_context()->memory_pool());
if (ErrorIfNotOk(maybe_batch.status())) return;
batches_.push_back(maybe_batch.MoveValueUnsafe());
}

if (input_counter_.Increment()) {
Finish();
}
}

protected:
Status DoFinish() {
Datum sorted;
{
std::unique_lock<std::mutex> lock(mutex_);
ARROW_ASSIGN_OR_RAISE(
auto table,
Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_)));
ARROW_ASSIGN_OR_RAISE(auto indices,
SortIndices(table, sort_options_, plan()->exec_context()));
ARROW_ASSIGN_OR_RAISE(sorted, Take(table, indices, TakeOptions::NoBoundsCheck(),
plan()->exec_context()));
}
TableBatchReader reader(*sorted.table());
while (true) {
std::shared_ptr<RecordBatch> batch;
RETURN_NOT_OK(reader.ReadNext(&batch));
if (!batch) break;
bool did_push = producer_.Push(ExecBatch(*batch));
if (!did_push) break; // producer_ was Closed already
}
return Status::OK();
}

void Finish() override {
Status st = DoFinish();
if (ErrorIfNotOk(st)) {
producer_.Push(std::move(st));
}
SinkNode::Finish();
}

private:
SortOptions sort_options_;
std::mutex mutex_;
std::vector<std::shared_ptr<RecordBatch>> batches_;
};

ExecFactoryRegistry::AddOnLoad kRegisterSink("sink", SinkNode::Make);
ExecFactoryRegistry::AddOnLoad kRegisterOrderBySink("order_by_sink",
OrderBySinkNode::Make);

} // namespace
} // namespace compute
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/arrow/compute/exec/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "arrow/compute/exec/util.h"

#include "arrow/compute/exec/exec_plan.h"
#include "arrow/table.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/ubsan.h"
Expand Down Expand Up @@ -296,5 +297,15 @@ Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inpu
return Status::OK();
}

Result<std::shared_ptr<Table>> TableFromExecBatches(
const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches) {
RecordBatchVector batches;
for (const auto& batch : exec_batches) {
ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema));
batches.push_back(std::move(rb));
}
return Table::FromRecordBatches(schema, batches);
}

} // namespace compute
} // namespace arrow
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/exec/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ ARROW_EXPORT
Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
int expected_num_inputs, const char* kind_name);

ARROW_EXPORT
Result<std::shared_ptr<Table>> TableFromExecBatches(
const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches);

class AtomicCounter {
public:
AtomicCounter() = default;
Expand Down