diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index a853a74362d..acc79bdfdde 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -23,6 +23,7 @@ #include #include "arrow/compute/api_aggregate.h" +#include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" #include "arrow/compute/exec/expression.h" #include "arrow/util/optional.h" @@ -111,5 +112,19 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions { std::function>()>* generator; }; +/// \brief Make a node which sorts rows passed through it +/// +/// All batches pushed to this node will be accumulated, then sorted, by the given +/// fields. Then sorted batches will be forwarded to the generator in sorted order. +class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions { + public: + explicit OrderBySinkNodeOptions( + SortOptions sort_options, + std::function>()>* generator) + : SinkNodeOptions(generator), sort_options(std::move(sort_options)) {} + + SortOptions sort_options; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 3b3d39fd36a..f4d81ace040 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -25,7 +25,9 @@ #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/test_util.h" +#include "arrow/compute/exec/util.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" @@ -36,6 +38,7 @@ #include "arrow/util/vector.h" using testing::ElementsAre; +using testing::ElementsAreArray; using testing::HasSubstr; using testing::Optional; using testing::UnorderedElementsAreArray; @@ -262,6 +265,7 @@ BatchesWithSchema MakeBasicBatches() { BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, int num_batches = 10, int batch_size = 4) { BatchesWithSchema out; + out.schema = schema; random::RandomArrayGenerator rng(42); out.batches.resize(num_batches); @@ -301,6 +305,36 @@ TEST(ExecPlanExecution, SourceSink) { } } +TEST(ExecPlanExecution, SourceOrderBy) { + std::vector expected = { + ExecBatchFromJSON({int32(), boolean()}, + "[[4, false], [5, null], [6, false], [7, false], [null, true]]")}; + for (bool slow : {false, true}) { + SCOPED_TRACE(slow ? "slowed" : "unslowed"); + + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel" : "single threaded"); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + AsyncGenerator> sink_gen; + + auto basic_data = MakeBasicBatches(); + + SortOptions options({SortKey("i32", SortOrder::Ascending)}); + ASSERT_OK(Declaration::Sequence( + { + {"source", SourceNodeOptions{basic_data.schema, + basic_data.gen(parallel, slow)}}, + {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}}, + }) + .AddToPlan(plan.get())); + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(ElementsAreArray(expected)))); + } + } +} + TEST(ExecPlanExecution, SourceSinkError) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); AsyncGenerator> sink_gen; @@ -355,6 +389,43 @@ TEST(ExecPlanExecution, StressSourceSink) { } } +TEST(ExecPlanExecution, StressSourceOrderBy) { + auto input_schema = schema({field("a", int32()), field("b", boolean())}); + for (bool slow : {false, true}) { + SCOPED_TRACE(slow ? "slowed" : "unslowed"); + + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel" : "single threaded"); + + int num_batches = slow && !parallel ? 30 : 300; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + AsyncGenerator> sink_gen; + + auto random_data = MakeRandomBatches(input_schema, num_batches); + + SortOptions options({SortKey("a", SortOrder::Ascending)}); + ASSERT_OK(Declaration::Sequence( + { + {"source", SourceNodeOptions{random_data.schema, + random_data.gen(parallel, slow)}}, + {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}}, + }) + .AddToPlan(plan.get())); + + // Check that data is sorted appropriately + ASSERT_FINISHES_OK_AND_ASSIGN(auto exec_batches, + StartAndCollect(plan.get(), sink_gen)); + ASSERT_OK_AND_ASSIGN(auto actual, TableFromExecBatches(input_schema, exec_batches)); + ASSERT_OK_AND_ASSIGN(auto original, + TableFromExecBatches(input_schema, random_data.batches)); + ASSERT_OK_AND_ASSIGN(auto sort_indices, SortIndices(original, options)); + ASSERT_OK_AND_ASSIGN(auto expected, Take(original, sort_indices)); + AssertTablesEqual(*actual, *expected.table()); + } + } +} + TEST(ExecPlanExecution, StressSourceSinkStopped) { for (bool slow : {false, true}) { SCOPED_TRACE(slow ? "slowed" : "unslowed"); @@ -541,6 +612,45 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumFilter) { } } +TEST(ExecPlanExecution, SourceFilterProjectGroupedSumOrderBy) { + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel/merged" : "serial"); + + int batch_multiplicity = parallel ? 100 : 1; + auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + AsyncGenerator> sink_gen; + + SortOptions options({SortKey("str", SortOrder::Descending)}); + ASSERT_OK( + Declaration::Sequence( + { + {"source", + SourceNodeOptions{input.schema, input.gen(parallel, /*slow=*/false)}}, + {"filter", + FilterNodeOptions{greater_equal(field_ref("i32"), literal(0))}}, + {"project", ProjectNodeOptions{{ + field_ref("str"), + call("multiply", {field_ref("i32"), literal(2)}), + }}}, + {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}}, + /*targets=*/{"multiply(i32, 2)"}, + /*names=*/{"sum(multiply(i32, 2))"}, + /*keys=*/{"str"}}}, + {"filter", FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"), + literal(10 * batch_multiplicity))}}, + {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}}, + }) + .AddToPlan(plan.get())); + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(ElementsAreArray({ExecBatchFromJSON( + {int64(), utf8()}, parallel ? R"([[2000, "beta"], [3600, "alfa"]])" + : R"([[20, "beta"], [36, "alfa"]])")})))); + } +} + TEST(ExecPlanExecution, SourceScalarAggSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); AsyncGenerator> sink_gen; diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index e4a06e0d224..4d9f82e582b 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -20,6 +20,7 @@ #include +#include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/options.h" @@ -27,6 +28,7 @@ #include "arrow/compute/exec_internal.h" #include "arrow/datum.h" #include "arrow/result.h" +#include "arrow/table.h" #include "arrow/util/async_generator.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" @@ -135,8 +137,8 @@ class SinkNode : public ExecNode { } } - private: - void Finish() { + protected: + virtual void Finish() { if (producer_.Close()) { finished_.MarkFinished(); } @@ -148,7 +150,82 @@ class SinkNode : public ExecNode { PushGenerator>::Producer producer_; }; +// A sink node that accumulates inputs, then sorts them before emitting them. +struct OrderBySinkNode final : public SinkNode { + OrderBySinkNode(ExecPlan* plan, std::vector inputs, SortOptions sort_options, + AsyncGenerator>* generator) + : SinkNode(plan, std::move(inputs), generator), + sort_options_(std::move(sort_options)) {} + + const char* kind_name() override { return "OrderBySinkNode"; } + + static Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode")); + + const auto& sink_options = checked_cast(options); + return plan->EmplaceNode( + plan, std::move(inputs), sink_options.sort_options, sink_options.generator); + } + + void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { + DCHECK_EQ(input, inputs_[0]); + + // Accumulate data + { + std::unique_lock lock(mutex_); + auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(), + plan()->exec_context()->memory_pool()); + if (ErrorIfNotOk(maybe_batch.status())) return; + batches_.push_back(maybe_batch.MoveValueUnsafe()); + } + + if (input_counter_.Increment()) { + Finish(); + } + } + + protected: + Status DoFinish() { + Datum sorted; + { + std::unique_lock lock(mutex_); + ARROW_ASSIGN_OR_RAISE( + auto table, + Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_))); + ARROW_ASSIGN_OR_RAISE(auto indices, + SortIndices(table, sort_options_, plan()->exec_context())); + ARROW_ASSIGN_OR_RAISE(sorted, Take(table, indices, TakeOptions::NoBoundsCheck(), + plan()->exec_context())); + } + TableBatchReader reader(*sorted.table()); + while (true) { + std::shared_ptr batch; + RETURN_NOT_OK(reader.ReadNext(&batch)); + if (!batch) break; + bool did_push = producer_.Push(ExecBatch(*batch)); + if (!did_push) break; // producer_ was Closed already + } + return Status::OK(); + } + + void Finish() override { + Status st = DoFinish(); + if (ErrorIfNotOk(st)) { + producer_.Push(std::move(st)); + } + SinkNode::Finish(); + } + + private: + SortOptions sort_options_; + std::mutex mutex_; + std::vector> batches_; +}; + ExecFactoryRegistry::AddOnLoad kRegisterSink("sink", SinkNode::Make); +ExecFactoryRegistry::AddOnLoad kRegisterOrderBySink("order_by_sink", + OrderBySinkNode::Make); } // namespace } // namespace compute diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index eecc617c9c0..aad6dc3d587 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -18,6 +18,7 @@ #include "arrow/compute/exec/util.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/table.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/ubsan.h" @@ -296,5 +297,15 @@ Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector& inpu return Status::OK(); } +Result> TableFromExecBatches( + const std::shared_ptr& schema, const std::vector& exec_batches) { + RecordBatchVector batches; + for (const auto& batch : exec_batches) { + ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema)); + batches.push_back(std::move(rb)); + } + return Table::FromRecordBatches(schema, batches); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index b7cf0aeaa5e..8bd6a3c5d62 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -188,6 +188,10 @@ ARROW_EXPORT Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector& inputs, int expected_num_inputs, const char* kind_name); +ARROW_EXPORT +Result> TableFromExecBatches( + const std::shared_ptr& schema, const std::vector& exec_batches); + class AtomicCounter { public: AtomicCounter() = default;