Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/examples/arrow/compute_register_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class ExampleNode : public cp::ExecNode {
void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); }
void StopProducing() override { inputs_[0]->StopProducing(); }

void InputReceived(ExecNode* input, cp::ExecBatch batch) override {}
void InputReceived(ExecNode* input,
std::function<arrow::Result<cp::ExecBatch>()> task) override {}
void ErrorReceived(ExecNode* input, arrow::Status error) override {}
void InputFinished(ExecNode* input, int total_batches) override {}

Expand Down
51 changes: 32 additions & 19 deletions cpp/src/arrow/compute/exec/aggregate_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,21 @@ class ScalarAggregateNode : public ExecNode {
return Status::OK();
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
void InputReceived(ExecNode* input, std::function<Result<ExecBatch>()> task) override {
DCHECK_EQ(input, inputs_[0]);

auto thread_index = get_thread_index_();

if (ErrorIfNotOk(DoConsume(std::move(batch), thread_index))) return;
auto prev = task();
if (!prev.ok()) {
ErrorIfNotOk(prev.status());
return;
}
if (ErrorIfNotOk(DoConsume(prev.MoveValueUnsafe(), thread_index))) return;
Comment on lines +182 to +187
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto prev = task();
if (!prev.ok()) {
ErrorIfNotOk(prev.status());
return;
}
if (ErrorIfNotOk(DoConsume(prev.MoveValueUnsafe(), thread_index))) return;
auto func = [this] (Result<ExecBatch> task) {
ARROW_ASSIGN_OR_RAISE(auto prev, task());
auto thread_index = get_thread_index_();
return DoConsume(prev.MoveValueUnsafe(), thread_index);
};
plan_->scheduler()->SubmitTask(std::move(func));

This is what I'm thinking pipeline breakers would look like.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plan_->scheduler()->SubmitTask(std::move(func));

Yes that is the idea, but this PR is to enable that construction later, this PR is not going to define any scheduler or submitting logic.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we aren't going to address this now let's make another JIRA (taskify 3?) Something like, "Fix logic in existing nodes so that pipeline breakers submit and non-breakers forward" and then add a comment in all of these spots along the lines of...

// This node should be forwarding the task downstream but that will be addressed in ARROW-XYZ


if (input_counter_.Increment()) {
ErrorIfNotOk(Finish());
}
}

void ErrorReceived(ExecNode* input, Status error) override {
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->ErrorReceived(this, std::move(error));
Expand Down Expand Up @@ -235,17 +238,18 @@ class ScalarAggregateNode : public ExecNode {

private:
Status Finish() {
ExecBatch batch{{}, 1};
batch.values.resize(kernels_.size());

for (size_t i = 0; i < kernels_.size(); ++i) {
KernelContext ctx{plan()->exec_context()};
ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll(
kernels_[i], &ctx, std::move(states_[i])));
RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
}

outputs_[0]->InputReceived(this, std::move(batch));
auto task = [this]() -> Result<ExecBatch> {
ExecBatch batch{{}, 1};
batch.values.resize(kernels_.size());
for (size_t i = 0; i < kernels_.size(); ++i) {
KernelContext ctx{plan()->exec_context()};
ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll(
kernels_[i], &ctx, std::move(states_[i])));
RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
}
return batch;
};
outputs_[0]->InputReceived(this, std::move(task));
Comment on lines +241 to +252
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is good 👍

finished_.MarkFinished();
return Status::OK();
}
Expand Down Expand Up @@ -452,8 +456,12 @@ class GroupByNode : public ExecNode {
// bail if StopProducing was called
if (finished_.is_finished()) return;

int64_t batch_size = output_batch_size();
outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
auto task = [n, this]() -> Result<ExecBatch> {
int64_t batch_size = output_batch_size();
return out_data_.Slice(batch_size * n, batch_size);
};

outputs_[0]->InputReceived(this, std::move(task));

if (output_counter_.Increment()) {
finished_.MarkFinished();
Expand Down Expand Up @@ -483,13 +491,18 @@ class GroupByNode : public ExecNode {
return Status::OK();
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
void InputReceived(ExecNode* input, std::function<Result<ExecBatch>()> task) override {
// bail if StopProducing was called
if (finished_.is_finished()) return;

DCHECK_EQ(input, inputs_[0]);

if (ErrorIfNotOk(Consume(std::move(batch)))) return;
auto prev = task();
if (!prev.ok()) {
ErrorIfNotOk(prev.status());
return;
}
if (ErrorIfNotOk(Consume(prev.MoveValueUnsafe()))) return;

if (input_counter_.Increment()) {
ErrorIfNotOk(OutputResult());
Expand Down
19 changes: 8 additions & 11 deletions cpp/src/arrow/compute/exec/exec_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,37 +350,34 @@ void MapNode::StopProducing() {

Future<> MapNode::finished() { return finished_; }

void MapNode::SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn,
ExecBatch batch) {
void MapNode::SubmitTask(std::function<Result<ExecBatch>()> map_fn) {
Status status;
// This will be true if the node is stopped early due to an error or manual
// cancellation
if (input_counter_.Completed()) {
return;
}
auto task = [this, map_fn, batch]() {
auto guarantee = batch.guarantee;
auto output_batch = map_fn(std::move(batch));
auto task_wrapper = [this, map_fn]() {
auto output_batch = map_fn();
if (ErrorIfNotOk(output_batch.status())) {
return output_batch.status();
}
output_batch->guarantee = guarantee;
outputs_[0]->InputReceived(this, output_batch.MoveValueUnsafe());
outputs_[0]->InputReceived(this, IdentityTask(output_batch.MoveValueUnsafe()));
return Status::OK();
};

if (executor_) {
status = task_group_.AddTask([this, task]() -> Result<Future<>> {
return this->executor_->Submit(this->stop_source_.token(), [this, task]() {
auto status = task();
status = task_group_.AddTask([this, task_wrapper]() -> Result<Future<>> {
return this->executor_->Submit(this->stop_source_.token(), [this, task_wrapper]() {
auto status = task_wrapper();
if (this->input_counter_.Increment()) {
this->Finish(status);
}
return status;
});
});
} else {
status = task();
status = task_wrapper();
if (input_counter_.Increment()) {
this->Finish(status);
}
Expand Down
11 changes: 8 additions & 3 deletions cpp/src/arrow/compute/exec/exec_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ class ARROW_EXPORT ExecNode {
/// - these are allowed to call back into PauseProducing(), ResumeProducing()
/// and StopProducing()

/// Transfer input batch to ExecNode
virtual void InputReceived(ExecNode* input, ExecBatch batch) = 0;
/// Transfer the input task to ExecNode
virtual void InputReceived(ExecNode* input,
std::function<Result<ExecBatch>()> task) = 0;

/// Signal error to ExecNode
virtual void ErrorReceived(ExecNode* input, Status error) = 0;
Expand Down Expand Up @@ -226,6 +227,10 @@ class ARROW_EXPORT ExecNode {
std::string ToString() const;

protected:
static inline std::function<Result<ExecBatch>()> IdentityTask(ExecBatch batch) {
return [batch]() -> Result<ExecBatch> { return batch; };
}

ExecNode(ExecPlan* plan, NodeVector inputs, std::vector<std::string> input_labels,
std::shared_ptr<Schema> output_schema, int num_outputs);

Expand Down Expand Up @@ -277,7 +282,7 @@ class MapNode : public ExecNode {
Future<> finished() override;

protected:
void SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn, ExecBatch batch);
void SubmitTask(std::function<Result<ExecBatch>()> map_fn);

void Finish(Status finish_st = Status::OK());

Expand Down
14 changes: 10 additions & 4 deletions cpp/src/arrow/compute/exec/filter_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,19 @@ class FilterNode : public MapNode {
if (value.is_scalar()) continue;
ARROW_ASSIGN_OR_RAISE(value, Filter(value, mask, FilterOptions::Defaults()));
}
return ExecBatch::Make(std::move(values));

ARROW_ASSIGN_OR_RAISE(auto result, ExecBatch::Make(std::move(values)));
result.guarantee = target.guarantee;
return result;
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
void InputReceived(ExecNode* input, std::function<Result<ExecBatch>()> task) override {
DCHECK_EQ(input, inputs_[0]);
auto func = [this](ExecBatch batch) { return DoFilter(std::move(batch)); };
this->SubmitTask(std::move(func), std::move(batch));
auto func = [this, task]() -> Result<ExecBatch> {
ARROW_ASSIGN_OR_RAISE(auto batch, task());
return DoFilter(std::move(batch));
};
this->SubmitTask(std::move(func));
}

protected:
Expand Down
12 changes: 9 additions & 3 deletions cpp/src/arrow/compute/exec/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ class HashJoinNode : public ExecNode {

const char* kind_name() const override { return "HashJoinNode"; }

void InputReceived(ExecNode* input, ExecBatch batch) override {
void InputReceived(ExecNode* input, std::function<Result<ExecBatch>()> task) override {
ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end());

if (complete_.load()) {
Expand All @@ -494,7 +494,13 @@ class HashJoinNode : public ExecNode {
size_t thread_index = thread_indexer_();
int side = (input == inputs_[0]) ? 0 : 1;
{
Status status = impl_->InputReceived(thread_index, side, std::move(batch));
auto batch = task();
if (!batch.ok()) {
StopProducing();
ErrorIfNotOk(batch.status());
return;
}
Status status = impl_->InputReceived(thread_index, side, batch.MoveValueUnsafe());
if (!status.ok()) {
StopProducing();
ErrorIfNotOk(status);
Expand Down Expand Up @@ -573,7 +579,7 @@ class HashJoinNode : public ExecNode {

private:
void OutputBatchCallback(ExecBatch batch) {
outputs_[0]->InputReceived(this, std::move(batch));
outputs_[0]->InputReceived(this, IdentityTask(batch));
}

void FinishedCallback(int64_t total_num_batches) {
Expand Down
13 changes: 9 additions & 4 deletions cpp/src/arrow/compute/exec/project_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,18 @@ class ProjectNode : public MapNode {
ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target,
plan()->exec_context()));
}
return ExecBatch{std::move(values), target.length};
auto result = ExecBatch{std::move(values), target.length};
result.guarantee = target.guarantee;
return result;
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
void InputReceived(ExecNode* input, std::function<Result<ExecBatch>()> task) override {
DCHECK_EQ(input, inputs_[0]);
auto func = [this](ExecBatch batch) { return DoProject(std::move(batch)); };
this->SubmitTask(std::move(func), std::move(batch));
auto func = [this, task]() -> Result<ExecBatch> {
ARROW_ASSIGN_OR_RAISE(auto batch, task());
return DoProject(std::move(batch));
};
this->SubmitTask(std::move(func));
}

protected:
Expand Down
30 changes: 22 additions & 8 deletions cpp/src/arrow/compute/exec/sink_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,15 @@ class SinkNode : public ExecNode {

Future<> finished() override { return finished_; }

void InputReceived(ExecNode* input, ExecBatch batch) override {
void InputReceived(ExecNode* input, std::function<Result<ExecBatch>()> task) override {
DCHECK_EQ(input, inputs_[0]);

bool did_push = producer_.Push(std::move(batch));
auto batch = task();
if (!batch.ok()) {
ErrorIfNotOk(batch.status());
return;
}
bool did_push = producer_.Push(batch.MoveValueUnsafe());
if (!did_push) return; // producer_ was Closed already

if (input_counter_.Increment()) {
Expand Down Expand Up @@ -179,7 +184,7 @@ class ConsumingSinkNode : public ExecNode {

Future<> finished() override { return finished_; }

void InputReceived(ExecNode* input, ExecBatch batch) override {
void InputReceived(ExecNode* input, std::function<Result<ExecBatch>()> task) override {
DCHECK_EQ(input, inputs_[0]);

// This can happen if an error was received and the source hasn't yet stopped. Since
Expand All @@ -188,7 +193,12 @@ class ConsumingSinkNode : public ExecNode {
return;
}

Status consumption_status = consumer_->Consume(std::move(batch));
auto batch = task();
if (!batch.ok()) {
ErrorIfNotOk(batch.status());
return;
}
Status consumption_status = consumer_->Consume(batch.MoveValueUnsafe());
if (!consumption_status.ok()) {
if (input_counter_.Cancel()) {
Finish(std::move(consumption_status));
Expand Down Expand Up @@ -274,11 +284,15 @@ struct OrderBySinkNode final : public SinkNode {
sink_options.backpressure);
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
void InputReceived(ExecNode* input, std::function<Result<ExecBatch>()> task) override {
DCHECK_EQ(input, inputs_[0]);

auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(),
plan()->exec_context()->memory_pool());
auto batch = task();
if (!batch.ok()) {
ErrorIfNotOk(batch.status());
return;
}
auto maybe_batch = batch.ValueUnsafe().ToRecordBatch(
inputs_[0]->output_schema(), plan()->exec_context()->memory_pool());
if (ErrorIfNotOk(maybe_batch.status())) {
StopProducing();
if (input_counter_.Cancel()) {
Expand Down
21 changes: 12 additions & 9 deletions cpp/src/arrow/compute/exec/source_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ struct SourceNode : ExecNode {
[[noreturn]] static void NoInputs() {
Unreachable("no inputs; this should never be called");
}
[[noreturn]] void InputReceived(ExecNode*, ExecBatch) override { NoInputs(); }
[[noreturn]] void InputReceived(ExecNode*,
std::function<Result<ExecBatch>()>) override {
NoInputs();
}
[[noreturn]] void ErrorReceived(ExecNode*, Status) override { NoInputs(); }
[[noreturn]] void InputFinished(ExecNode*, int) override { NoInputs(); }

Expand Down Expand Up @@ -107,19 +110,19 @@ struct SourceNode : ExecNode {
ExecBatch batch = std::move(*maybe_batch);

if (executor) {
auto status =
task_group_.AddTask([this, executor, batch]() -> Result<Future<>> {
return executor->Submit([=]() {
outputs_[0]->InputReceived(this, std::move(batch));
return Status::OK();
});
});
auto status = task_group_.AddTask([this, executor,
batch]() -> Result<Future<>> {
return executor->Submit([=]() {
outputs_[0]->InputReceived(this, IdentityTask(std::move(batch)));
return Status::OK();
});
});
if (!status.ok()) {
outputs_[0]->ErrorReceived(this, std::move(status));
return Break(total_batches);
}
} else {
outputs_[0]->InputReceived(this, std::move(batch));
outputs_[0]->InputReceived(this, IdentityTask(std::move(batch)));
Comment on lines -110 to +125
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would actually not create a task but forward to downstream like filter/project.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what will this eventually look like? If we assume we don't know how many batches a scanner will emit then how many "scan tasks" do we submit individually? I suppose we can always "over-submit" and then the final tasks will just abandon themselves if the scanner is finished. Could this be another spot for backpressure? I don't think we have to solve all of these problems right now.

}
return Continue();
},
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec/test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct DummyNode : ExecNode {

const char* kind_name() const override { return "Dummy"; }

void InputReceived(ExecNode* input, ExecBatch batch) override {}
void InputReceived(ExecNode*, std::function<Result<ExecBatch>()>) override {}

void ErrorReceived(ExecNode* input, Status error) override {}

Expand Down
9 changes: 7 additions & 2 deletions cpp/src/arrow/compute/exec/union_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,18 @@ class UnionNode : public ExecNode {
return plan->EmplaceNode<UnionNode>(plan, std::move(inputs));
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
void InputReceived(ExecNode* input, std::function<Result<ExecBatch>()> task) override {
ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end());

if (finished_.is_finished()) {
return;
}
outputs_[0]->InputReceived(this, std::move(batch));
auto batch = task();
if (!batch.ok()) {
ErrorIfNotOk(batch.status());
return;
}
outputs_[0]->InputReceived(this, IdentityTask(batch.MoveValueUnsafe()));
if (batch_count_.Increment()) {
finished_.MarkFinished();
}
Expand Down
Loading