diff --git a/cpp/examples/arrow/compute_register_example.cc b/cpp/examples/arrow/compute_register_example.cc index 6e5ff015387..0508cb3617c 100644 --- a/cpp/examples/arrow/compute_register_example.cc +++ b/cpp/examples/arrow/compute_register_example.cc @@ -98,7 +98,8 @@ class ExampleNode : public cp::ExecNode { void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); } void StopProducing() override { inputs_[0]->StopProducing(); } - void InputReceived(ExecNode* input, cp::ExecBatch batch) override {} + void InputReceived(ExecNode* input, + std::function()> task) override {} void ErrorReceived(ExecNode* input, arrow::Status error) override {} void InputFinished(ExecNode* input, int total_batches) override {} diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index 59b2ff8b8af..128dfc87403 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -175,18 +175,21 @@ class ScalarAggregateNode : public ExecNode { return Status::OK(); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + void InputReceived(ExecNode* input, std::function()> task) override { DCHECK_EQ(input, inputs_[0]); auto thread_index = get_thread_index_(); - - if (ErrorIfNotOk(DoConsume(std::move(batch), thread_index))) return; + auto prev = task(); + if (!prev.ok()) { + ErrorIfNotOk(prev.status()); + return; + } + if (ErrorIfNotOk(DoConsume(prev.MoveValueUnsafe(), thread_index))) return; if (input_counter_.Increment()) { ErrorIfNotOk(Finish()); } } - void ErrorReceived(ExecNode* input, Status error) override { DCHECK_EQ(input, inputs_[0]); outputs_[0]->ErrorReceived(this, std::move(error)); @@ -235,17 +238,18 @@ class ScalarAggregateNode : public ExecNode { private: Status Finish() { - ExecBatch batch{{}, 1}; - batch.values.resize(kernels_.size()); - - for (size_t i = 0; i < kernels_.size(); ++i) { - KernelContext ctx{plan()->exec_context()}; - ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll( - kernels_[i], &ctx, std::move(states_[i]))); - RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i])); - } - - outputs_[0]->InputReceived(this, std::move(batch)); + auto task = [this]() -> Result { + ExecBatch batch{{}, 1}; + batch.values.resize(kernels_.size()); + for (size_t i = 0; i < kernels_.size(); ++i) { + KernelContext ctx{plan()->exec_context()}; + ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll( + kernels_[i], &ctx, std::move(states_[i]))); + RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i])); + } + return batch; + }; + outputs_[0]->InputReceived(this, std::move(task)); finished_.MarkFinished(); return Status::OK(); } @@ -452,8 +456,12 @@ class GroupByNode : public ExecNode { // bail if StopProducing was called if (finished_.is_finished()) return; - int64_t batch_size = output_batch_size(); - outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size)); + auto task = [n, this]() -> Result { + int64_t batch_size = output_batch_size(); + return out_data_.Slice(batch_size * n, batch_size); + }; + + outputs_[0]->InputReceived(this, std::move(task)); if (output_counter_.Increment()) { finished_.MarkFinished(); @@ -483,13 +491,18 @@ class GroupByNode : public ExecNode { return Status::OK(); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + void InputReceived(ExecNode* input, std::function()> task) override { // bail if StopProducing was called if (finished_.is_finished()) return; DCHECK_EQ(input, inputs_[0]); - if (ErrorIfNotOk(Consume(std::move(batch)))) return; + auto prev = task(); + if (!prev.ok()) { + ErrorIfNotOk(prev.status()); + return; + } + if (ErrorIfNotOk(Consume(prev.MoveValueUnsafe()))) return; if (input_counter_.Increment()) { ErrorIfNotOk(OutputResult()); diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 7cd3011b8ab..fecf797f68b 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -350,29 +350,26 @@ void MapNode::StopProducing() { Future<> MapNode::finished() { return finished_; } -void MapNode::SubmitTask(std::function(ExecBatch)> map_fn, - ExecBatch batch) { +void MapNode::SubmitTask(std::function()> map_fn) { Status status; // This will be true if the node is stopped early due to an error or manual // cancellation if (input_counter_.Completed()) { return; } - auto task = [this, map_fn, batch]() { - auto guarantee = batch.guarantee; - auto output_batch = map_fn(std::move(batch)); + auto task_wrapper = [this, map_fn]() { + auto output_batch = map_fn(); if (ErrorIfNotOk(output_batch.status())) { return output_batch.status(); } - output_batch->guarantee = guarantee; - outputs_[0]->InputReceived(this, output_batch.MoveValueUnsafe()); + outputs_[0]->InputReceived(this, IdentityTask(output_batch.MoveValueUnsafe())); return Status::OK(); }; if (executor_) { - status = task_group_.AddTask([this, task]() -> Result> { - return this->executor_->Submit(this->stop_source_.token(), [this, task]() { - auto status = task(); + status = task_group_.AddTask([this, task_wrapper]() -> Result> { + return this->executor_->Submit(this->stop_source_.token(), [this, task_wrapper]() { + auto status = task_wrapper(); if (this->input_counter_.Increment()) { this->Finish(status); } @@ -380,7 +377,7 @@ void MapNode::SubmitTask(std::function(ExecBatch)> map_fn, }); }); } else { - status = task(); + status = task_wrapper(); if (input_counter_.Increment()) { this->Finish(status); } diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 4cb7fad009f..50fdb79765a 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -134,8 +134,9 @@ class ARROW_EXPORT ExecNode { /// - these are allowed to call back into PauseProducing(), ResumeProducing() /// and StopProducing() - /// Transfer input batch to ExecNode - virtual void InputReceived(ExecNode* input, ExecBatch batch) = 0; + /// Transfer the input task to ExecNode + virtual void InputReceived(ExecNode* input, + std::function()> task) = 0; /// Signal error to ExecNode virtual void ErrorReceived(ExecNode* input, Status error) = 0; @@ -226,6 +227,10 @@ class ARROW_EXPORT ExecNode { std::string ToString() const; protected: + static inline std::function()> IdentityTask(ExecBatch batch) { + return [batch]() -> Result { return batch; }; + } + ExecNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, std::shared_ptr output_schema, int num_outputs); @@ -277,7 +282,7 @@ class MapNode : public ExecNode { Future<> finished() override; protected: - void SubmitTask(std::function(ExecBatch)> map_fn, ExecBatch batch); + void SubmitTask(std::function()> map_fn); void Finish(Status finish_st = Status::OK()); diff --git a/cpp/src/arrow/compute/exec/filter_node.cc b/cpp/src/arrow/compute/exec/filter_node.cc index 2e6d974dc13..5aae9216306 100644 --- a/cpp/src/arrow/compute/exec/filter_node.cc +++ b/cpp/src/arrow/compute/exec/filter_node.cc @@ -89,13 +89,19 @@ class FilterNode : public MapNode { if (value.is_scalar()) continue; ARROW_ASSIGN_OR_RAISE(value, Filter(value, mask, FilterOptions::Defaults())); } - return ExecBatch::Make(std::move(values)); + + ARROW_ASSIGN_OR_RAISE(auto result, ExecBatch::Make(std::move(values))); + result.guarantee = target.guarantee; + return result; } - void InputReceived(ExecNode* input, ExecBatch batch) override { + void InputReceived(ExecNode* input, std::function()> task) override { DCHECK_EQ(input, inputs_[0]); - auto func = [this](ExecBatch batch) { return DoFilter(std::move(batch)); }; - this->SubmitTask(std::move(func), std::move(batch)); + auto func = [this, task]() -> Result { + ARROW_ASSIGN_OR_RAISE(auto batch, task()); + return DoFilter(std::move(batch)); + }; + this->SubmitTask(std::move(func)); } protected: diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 51e2e97cb1a..3865559127f 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -484,7 +484,7 @@ class HashJoinNode : public ExecNode { const char* kind_name() const override { return "HashJoinNode"; } - void InputReceived(ExecNode* input, ExecBatch batch) override { + void InputReceived(ExecNode* input, std::function()> task) override { ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); if (complete_.load()) { @@ -494,7 +494,13 @@ class HashJoinNode : public ExecNode { size_t thread_index = thread_indexer_(); int side = (input == inputs_[0]) ? 0 : 1; { - Status status = impl_->InputReceived(thread_index, side, std::move(batch)); + auto batch = task(); + if (!batch.ok()) { + StopProducing(); + ErrorIfNotOk(batch.status()); + return; + } + Status status = impl_->InputReceived(thread_index, side, batch.MoveValueUnsafe()); if (!status.ok()) { StopProducing(); ErrorIfNotOk(status); @@ -573,7 +579,7 @@ class HashJoinNode : public ExecNode { private: void OutputBatchCallback(ExecBatch batch) { - outputs_[0]->InputReceived(this, std::move(batch)); + outputs_[0]->InputReceived(this, IdentityTask(batch)); } void FinishedCallback(int64_t total_num_batches) { diff --git a/cpp/src/arrow/compute/exec/project_node.cc b/cpp/src/arrow/compute/exec/project_node.cc index c675acb3d98..4b8bea08800 100644 --- a/cpp/src/arrow/compute/exec/project_node.cc +++ b/cpp/src/arrow/compute/exec/project_node.cc @@ -85,13 +85,18 @@ class ProjectNode : public MapNode { ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target, plan()->exec_context())); } - return ExecBatch{std::move(values), target.length}; + auto result = ExecBatch{std::move(values), target.length}; + result.guarantee = target.guarantee; + return result; } - void InputReceived(ExecNode* input, ExecBatch batch) override { + void InputReceived(ExecNode* input, std::function()> task) override { DCHECK_EQ(input, inputs_[0]); - auto func = [this](ExecBatch batch) { return DoProject(std::move(batch)); }; - this->SubmitTask(std::move(func), std::move(batch)); + auto func = [this, task]() -> Result { + ARROW_ASSIGN_OR_RAISE(auto batch, task()); + return DoProject(std::move(batch)); + }; + this->SubmitTask(std::move(func)); } protected: diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index 1bb2680383c..e2c047b47da 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -95,10 +95,15 @@ class SinkNode : public ExecNode { Future<> finished() override { return finished_; } - void InputReceived(ExecNode* input, ExecBatch batch) override { + void InputReceived(ExecNode* input, std::function()> task) override { DCHECK_EQ(input, inputs_[0]); - bool did_push = producer_.Push(std::move(batch)); + auto batch = task(); + if (!batch.ok()) { + ErrorIfNotOk(batch.status()); + return; + } + bool did_push = producer_.Push(batch.MoveValueUnsafe()); if (!did_push) return; // producer_ was Closed already if (input_counter_.Increment()) { @@ -179,7 +184,7 @@ class ConsumingSinkNode : public ExecNode { Future<> finished() override { return finished_; } - void InputReceived(ExecNode* input, ExecBatch batch) override { + void InputReceived(ExecNode* input, std::function()> task) override { DCHECK_EQ(input, inputs_[0]); // This can happen if an error was received and the source hasn't yet stopped. Since @@ -188,7 +193,12 @@ class ConsumingSinkNode : public ExecNode { return; } - Status consumption_status = consumer_->Consume(std::move(batch)); + auto batch = task(); + if (!batch.ok()) { + ErrorIfNotOk(batch.status()); + return; + } + Status consumption_status = consumer_->Consume(batch.MoveValueUnsafe()); if (!consumption_status.ok()) { if (input_counter_.Cancel()) { Finish(std::move(consumption_status)); @@ -274,11 +284,15 @@ struct OrderBySinkNode final : public SinkNode { sink_options.backpressure); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + void InputReceived(ExecNode* input, std::function()> task) override { DCHECK_EQ(input, inputs_[0]); - - auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(), - plan()->exec_context()->memory_pool()); + auto batch = task(); + if (!batch.ok()) { + ErrorIfNotOk(batch.status()); + return; + } + auto maybe_batch = batch.ValueUnsafe().ToRecordBatch( + inputs_[0]->output_schema(), plan()->exec_context()->memory_pool()); if (ErrorIfNotOk(maybe_batch.status())) { StopProducing(); if (input_counter_.Cancel()) { diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index 46bba5609d4..e00cbb7a0e6 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -61,7 +61,10 @@ struct SourceNode : ExecNode { [[noreturn]] static void NoInputs() { Unreachable("no inputs; this should never be called"); } - [[noreturn]] void InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } + [[noreturn]] void InputReceived(ExecNode*, + std::function()>) override { + NoInputs(); + } [[noreturn]] void ErrorReceived(ExecNode*, Status) override { NoInputs(); } [[noreturn]] void InputFinished(ExecNode*, int) override { NoInputs(); } @@ -107,19 +110,19 @@ struct SourceNode : ExecNode { ExecBatch batch = std::move(*maybe_batch); if (executor) { - auto status = - task_group_.AddTask([this, executor, batch]() -> Result> { - return executor->Submit([=]() { - outputs_[0]->InputReceived(this, std::move(batch)); - return Status::OK(); - }); - }); + auto status = task_group_.AddTask([this, executor, + batch]() -> Result> { + return executor->Submit([=]() { + outputs_[0]->InputReceived(this, IdentityTask(std::move(batch))); + return Status::OK(); + }); + }); if (!status.ok()) { outputs_[0]->ErrorReceived(this, std::move(status)); return Break(total_batches); } } else { - outputs_[0]->InputReceived(this, std::move(batch)); + outputs_[0]->InputReceived(this, IdentityTask(std::move(batch))); } return Continue(); }, diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 64f3ec997c9..d137a6c947c 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -71,7 +71,7 @@ struct DummyNode : ExecNode { const char* kind_name() const override { return "Dummy"; } - void InputReceived(ExecNode* input, ExecBatch batch) override {} + void InputReceived(ExecNode*, std::function()>) override {} void ErrorReceived(ExecNode* input, Status error) override {} diff --git a/cpp/src/arrow/compute/exec/union_node.cc b/cpp/src/arrow/compute/exec/union_node.cc index fef2f4e1866..ae2d2989418 100644 --- a/cpp/src/arrow/compute/exec/union_node.cc +++ b/cpp/src/arrow/compute/exec/union_node.cc @@ -74,13 +74,18 @@ class UnionNode : public ExecNode { return plan->EmplaceNode(plan, std::move(inputs)); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + void InputReceived(ExecNode* input, std::function()> task) override { ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); if (finished_.is_finished()) { return; } - outputs_[0]->InputReceived(this, std::move(batch)); + auto batch = task(); + if (!batch.ok()) { + ErrorIfNotOk(batch.status()); + return; + } + outputs_[0]->InputReceived(this, IdentityTask(batch.MoveValueUnsafe())); if (batch_count_.Increment()) { finished_.MarkFinished(); } diff --git a/docs/source/cpp/streaming_execution.rst b/docs/source/cpp/streaming_execution.rst index 5864857c177..b76b6d4080c 100644 --- a/docs/source/cpp/streaming_execution.rst +++ b/docs/source/cpp/streaming_execution.rst @@ -95,8 +95,14 @@ through unchanged:: class PassthruNode : public ExecNode { public: // InputReceived is the main entry point for ExecNodes. It is invoked - // by an input of this node to push a batch here for processing. - void InputReceived(ExecNode* input, ExecBatch batch) override { + // by an input of this node to push a task here for processing. + // For non-terminating nodes (e.g. filter/project/etc.): the node can wrap + // its own work with the task (using function composition/fusing) and then + // call InputReceived on the downstream node. + // A "terminating node" (e.g. sink node / pipeline breaker) could then submit + // the task to a scheduler. + void InputReceived(ExecNode* input, + std::function()> task) override { // Since this is a passthru node we simply push the batch to our // only output here. outputs_[0]->InputReceived(this, batch);