101 changes: 101 additions & 0 deletions cpp/src/arrow/compute/exec/exec_plan.cc
@@ -287,6 +287,107 @@ bool ExecNode::ErrorIfNotOk(Status status) {
return true;
}

MapNode::MapNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema, bool async_mode)
: ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"},
std::move(output_schema),
/*num_outputs=*/1) {
if (async_mode) {
executor_ = plan_->exec_context()->executor();
} else {
executor_ = nullptr;
}
}

void MapNode::ErrorReceived(ExecNode* input, Status error) {
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->ErrorReceived(this, std::move(error));
}

void MapNode::InputFinished(ExecNode* input, int total_batches) {
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->InputFinished(this, total_batches);
if (input_counter_.SetTotal(total_batches)) {
this->Finish();
}
}

Status MapNode::StartProducing() { return Status::OK(); }

void MapNode::PauseProducing(ExecNode* output) {}

void MapNode::ResumeProducing(ExecNode* output) {}

void MapNode::StopProducing(ExecNode* output) {
DCHECK_EQ(output, outputs_[0]);
StopProducing();
}

void MapNode::StopProducing() {
if (executor_) {
this->stop_source_.RequestStop();
}
if (input_counter_.Cancel()) {
this->Finish();
}
inputs_[0]->StopProducing(this);
}

Future<> MapNode::finished() { return finished_; }

void MapNode::SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn,
ExecBatch batch) {
Status status;
if (finished_.is_finished()) {
return;
}
auto task = [this, map_fn, batch]() {
auto guarantee = batch.guarantee;
auto output_batch = map_fn(std::move(batch));
if (ErrorIfNotOk(output_batch.status())) {
return output_batch.status();
}
output_batch->guarantee = guarantee;
outputs_[0]->InputReceived(this, output_batch.MoveValueUnsafe());
return Status::OK();
};

if (executor_) {
status = task_group_.AddTask([this, task]() -> Result<Future<>> {
return this->executor_->Submit(this->stop_source_.token(), [this, task]() {
auto status = task();
if (this->input_counter_.Increment()) {
this->Finish(status);
}
return status;
});
});
} else {
status = task();
if (input_counter_.Increment()) {
this->Finish(status);
}
}
if (!status.ok()) {
if (input_counter_.Cancel()) {
this->Finish(status);
}
inputs_[0]->StopProducing(this);
return;
}
}

void MapNode::Finish(Status finish_st /*= Status::OK()*/) {
if (executor_) {
task_group_.End().AddCallback([this, finish_st](const Status& st) {
Status final_status = finish_st & st;
this->finished_.MarkFinished(final_status);
});
} else {
this->finished_.MarkFinished(finish_st);
}
}

std::shared_ptr<RecordBatchReader> MakeGeneratorReader(
std::shared_ptr<Schema> schema,
std::function<Future<util::optional<ExecBatch>>()> gen, MemoryPool* pool) {
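
Aside on the status combination in MapNode::Finish above: the expression finish_st & st relies on arrow::Status::operator&, which (as far as I recall) returns the left-hand status when it is an error and the right-hand status otherwise, so an explicit failure passed into Finish takes precedence over whatever AsyncTaskGroup::End() reports. A minimal sketch of that behavior, not part of this diff and assuming that operator& semantics:

#include "arrow/status.h"

// Sketch: how MapNode::Finish combines its argument with the task group's status.
// Assumes Status::operator& keeps the left status if it is an error, else the right.
arrow::Status CombineFinishStatus(const arrow::Status& finish_st,
                                  const arrow::Status& task_group_st) {
  // If finish_st is already an error (e.g. a failed map_fn), it wins;
  // otherwise any error from ending the task group is reported.
  return finish_st & task_group_st;
}
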
54 changes: 54 additions & 0 deletions cpp/src/arrow/compute/exec/exec_plan.h
@@ -23,13 +23,17 @@
#include <vector>

#include "arrow/compute/exec.h"
#include "arrow/compute/exec/util.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/type_fwd.h"
#include "arrow/util/async_util.h"
#include "arrow/util/cancel.h"
#include "arrow/util/macros.h"
#include "arrow/util/optional.h"
#include "arrow/util/visibility.h"

namespace arrow {

namespace compute {

class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
@@ -243,6 +247,56 @@ class ARROW_EXPORT ExecNode {
NodeVector outputs_;
};

/// \brief MapNode is an ExecNode base class for nodes that apply a map-like
/// task such as filter or project (see the SubmitTask method) to each given
/// ExecBatch. Such nodes have one input and one output, and the task is a pure
/// function of the input batch.
///
/// A simple parallel runner is created from a "map_fn", which is just a function
/// that takes a batch in and returns a batch. The parallel runner requires an
/// executor; when no executor is available, a simple synchronous runner is used
/// instead.

class MapNode : public ExecNode {
public:
MapNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema, bool async_mode);

void ErrorReceived(ExecNode* input, Status error) override;

void InputFinished(ExecNode* input, int total_batches) override;

Status StartProducing() override;

void PauseProducing(ExecNode* output) override;

void ResumeProducing(ExecNode* output) override;

void StopProducing(ExecNode* output) override;

void StopProducing() override;

Future<> finished() override;

protected:
void SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn, ExecBatch batch);

void Finish(Status finish_st = Status::OK());

protected:
// Counter for the number of batches received
AtomicCounter input_counter_;

// Future that is marked finished once all batches have been processed
Future<> finished_ = Future<>::Make();

// Task group tracking the per-batch tasks scheduled on the executor
util::AsyncTaskGroup task_group_;

::arrow::internal::Executor* executor_;

// Variable used to cancel remaining tasks in the executor
StopSource stop_source_;
};

/// \brief An extensible registry for factories of ExecNodes
class ARROW_EXPORT ExecFactoryRegistry {
public:
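
To make the contract above concrete, here is a minimal, hypothetical subclass sketch (not part of this PR; the node name and its no-op transformation are invented for illustration). The derived node only supplies its per-batch function and hands it to SubmitTask, exactly as the reworked FilterNode and ProjectNode below do; MapNode takes care of scheduling on the executor in async mode, batch counting, and finish/cancel handling.

class PassThroughNode : public MapNode {
 public:
  PassThroughNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
                  std::shared_ptr<Schema> output_schema, bool async_mode)
      : MapNode(plan, std::move(inputs), std::move(output_schema), async_mode) {}

  const char* kind_name() const override { return "PassThroughNode"; }

  void InputReceived(ExecNode* input, ExecBatch batch) override {
    DCHECK_EQ(input, inputs_[0]);
    // The map function must be a pure function of the batch; here it simply
    // forwards the batch unchanged.
    auto func = [](ExecBatch batch) -> Result<ExecBatch> { return batch; };
    this->SubmitTask(std::move(func), std::move(batch));
  }

 protected:
  std::string ToStringExtra() const override { return "pass_through"; }
};
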
49 changes: 7 additions & 42 deletions cpp/src/arrow/compute/exec/filter_node.cc
@@ -21,27 +21,23 @@
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/util.h"
#include "arrow/datum.h"
#include "arrow/result.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/future.h"
#include "arrow/util/logging.h"

namespace arrow {

using internal::checked_cast;

namespace compute {
namespace {

class FilterNode : public ExecNode {
class FilterNode : public MapNode {
public:
FilterNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema, Expression filter)
: ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"},
std::move(output_schema),
/*num_outputs=*/1),
std::shared_ptr<Schema> output_schema, Expression filter, bool async_mode)
: MapNode(plan, std::move(inputs), std::move(output_schema), async_mode),
filter_(std::move(filter)) {}

static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
@@ -61,9 +57,9 @@ class FilterNode : public ExecNode {
filter_expression.ToString(), " evaluates to ",
filter_expression.type()->ToString());
}

return plan->EmplaceNode<FilterNode>(plan, std::move(inputs), std::move(schema),
std::move(filter_expression));
std::move(filter_expression),
filter_options.async_mode);
}

const char* kind_name() const override { return "FilterNode"; }
@@ -98,50 +94,19 @@ class FilterNode : public ExecNode {

void InputReceived(ExecNode* input, ExecBatch batch) override {
DCHECK_EQ(input, inputs_[0]);

auto maybe_filtered = DoFilter(std::move(batch));
if (ErrorIfNotOk(maybe_filtered.status())) return;

maybe_filtered->guarantee = batch.guarantee;
outputs_[0]->InputReceived(this, maybe_filtered.MoveValueUnsafe());
}

void ErrorReceived(ExecNode* input, Status error) override {
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->InputFinished(this, total_batches);
auto func = [this](ExecBatch batch) { return DoFilter(std::move(batch)); };
this->SubmitTask(std::move(func), std::move(batch));
}

Status StartProducing() override { return Status::OK(); }

void PauseProducing(ExecNode* output) override {}

void ResumeProducing(ExecNode* output) override {}

void StopProducing(ExecNode* output) override {
DCHECK_EQ(output, outputs_[0]);
StopProducing();
}

void StopProducing() override { inputs_[0]->StopProducing(this); }

Future<> finished() override { return inputs_[0]->finished(); }

protected:
std::string ToStringExtra() const override { return "filter=" + filter_.ToString(); }

private:
Expression filter_;
};

} // namespace

namespace internal {

void RegisterFilterNode(ExecFactoryRegistry* registry) {
DCHECK_OK(registry->AddFactory("filter", FilterNode::Make));
}
12 changes: 8 additions & 4 deletions cpp/src/arrow/compute/exec/options.h
@@ -58,10 +58,11 @@ class ARROW_EXPORT SourceNodeOptions : public ExecNodeOptions {
/// excluded in the batch emitted by this node.
class ARROW_EXPORT FilterNodeOptions : public ExecNodeOptions {
public:
explicit FilterNodeOptions(Expression filter_expression)
: filter_expression(std::move(filter_expression)) {}
explicit FilterNodeOptions(Expression filter_expression, bool async_mode = true)
: filter_expression(std::move(filter_expression)), async_mode(async_mode) {}

Expression filter_expression;
bool async_mode;
};

/// \brief Make a node which executes expressions on input batches, producing new batches.
@@ -73,11 +74,14 @@ class ARROW_EXPORT FilterNodeOptions : public ExecNodeOptions {
class ARROW_EXPORT ProjectNodeOptions : public ExecNodeOptions {
public:
explicit ProjectNodeOptions(std::vector<Expression> expressions,
std::vector<std::string> names = {})
: expressions(std::move(expressions)), names(std::move(names)) {}
std::vector<std::string> names = {}, bool async_mode = true)
: expressions(std::move(expressions)),
names(std::move(names)),
async_mode(async_mode) {}

std::vector<Expression> expressions;
std::vector<std::string> names;
bool async_mode;
};

/// \brief Make a node which aggregates input batches, optionally grouped by keys.
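
For reference, a hedged usage sketch of the new flag (the expressions and output name below are placeholders, and the expression helpers call/field_ref/literal are assumed from arrow/compute/exec/expression.h). async_mode defaults to true, so existing callers keep the threaded behavior; passing false selects the synchronous path, where the map function runs on the thread that delivers the batch.

// Filter with the new synchronous mode (async_mode = false).
compute::FilterNodeOptions filter_opts(
    compute::call("greater", {compute::field_ref("x"), compute::literal(3)}),
    /*async_mode=*/false);

// Project keeps the default threaded behavior (async_mode = true).
compute::ProjectNodeOptions project_opts(
    /*expressions=*/{compute::call("add", {compute::field_ref("x"), compute::literal(1)})},
    /*names=*/{"x_plus_one"});
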
46 changes: 8 additions & 38 deletions cpp/src/arrow/compute/exec/project_node.cc
@@ -37,13 +37,12 @@ using internal::checked_cast;
namespace compute {
namespace {

class ProjectNode : public ExecNode {
class ProjectNode : public MapNode {
public:
ProjectNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema, std::vector<Expression> exprs)
: ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"},
std::move(output_schema),
/*num_outputs=*/1),
std::shared_ptr<Schema> output_schema, std::vector<Expression> exprs,
bool async_mode)
: MapNode(plan, std::move(inputs), std::move(output_schema), async_mode),
exprs_(std::move(exprs)) {}

static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
@@ -70,9 +69,9 @@ class ProjectNode : public ExecNode {
fields[i] = field(std::move(names[i]), expr.type());
++i;
}

return plan->EmplaceNode<ProjectNode>(plan, std::move(inputs),
schema(std::move(fields)), std::move(exprs));
schema(std::move(fields)), std::move(exprs),
project_options.async_mode);
}

const char* kind_name() const override { return "ProjectNode"; }
@@ -91,39 +90,10 @@ class ProjectNode : public ExecNode {

void InputReceived(ExecNode* input, ExecBatch batch) override {
DCHECK_EQ(input, inputs_[0]);

auto maybe_projected = DoProject(std::move(batch));
if (ErrorIfNotOk(maybe_projected.status())) return;

maybe_projected->guarantee = batch.guarantee;
outputs_[0]->InputReceived(this, maybe_projected.MoveValueUnsafe());
}

void ErrorReceived(ExecNode* input, Status error) override {
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->InputFinished(this, total_batches);
}

Status StartProducing() override { return Status::OK(); }

void PauseProducing(ExecNode* output) override {}

void ResumeProducing(ExecNode* output) override {}

void StopProducing(ExecNode* output) override {
DCHECK_EQ(output, outputs_[0]);
StopProducing();
auto func = [this](ExecBatch batch) { return DoProject(std::move(batch)); };
this->SubmitTask(std::move(func), std::move(batch));
}

void StopProducing() override { inputs_[0]->StopProducing(this); }

Future<> finished() override { return inputs_[0]->finished(); }

protected:
std::string ToStringExtra() const override {
std::stringstream ss;
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/exec/sink_node.cc
@@ -274,9 +274,9 @@ struct OrderBySinkNode final : public SinkNode {
plan()->exec_context()->memory_pool());
if (ErrorIfNotOk(maybe_batch.status())) {
StopProducing();
bool cancelled = input_counter_.Cancel();
DCHECK(cancelled);
finished_.MarkFinished(maybe_batch.status());
if (input_counter_.Cancel()) {
finished_.MarkFinished(maybe_batch.status());
}
return;
}
auto record_batch = maybe_batch.MoveValueUnsafe();