diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index ec3e6fba230..a6a97553154 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -280,5 +280,24 @@ class ARROW_EXPORT SelectKSinkNodeOptions : public SinkNodeOptions {
 
 /// @}
 
+/// \brief Adapt a Table as a sink node
+///
+/// Collects the output of an execution plan into a table and publishes it
+/// through a caller-provided pointer.
+class ARROW_EXPORT TableSinkNodeOptions : public ExecNodeOptions {
+ public:
+  /// \param output_table destination that receives the resulting table; it is
+  ///        stored as a raw pointer and written when the sink finishes
+  /// \param output_schema schema used to convert the incoming batches
+  TableSinkNodeOptions(std::shared_ptr<Table>* output_table,
+                       std::shared_ptr<Schema> output_schema)
+      : output_table(output_table), output_schema(std::move(output_schema)) {}
+
+  /// Non-owning pointer to the caller's table handle
+  std::shared_ptr<Table>* output_table;
+  /// Schema of the produced table
+  std::shared_ptr<Schema> output_schema;
+};
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 258238dbb81..b4b24e832ef 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -488,6 +488,40 @@ TEST(ExecPlanExecution, SourceConsumingSink) {
   }
 }
 
+// Checks that "table_sink" delivers the source's batches as a single table.
+TEST(ExecPlanExecution, SourceTableConsumingSink) {
+  for (bool slow : {false, true}) {
+    SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+    for (bool parallel : {false, true}) {
+      SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+      ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+
+      std::shared_ptr<Table> out;
+
+      auto basic_data = MakeBasicBatches();
+
+      TableSinkNodeOptions options{&out, basic_data.schema};
+
+      ASSERT_OK_AND_ASSIGN(
+          auto source, MakeExecNode("source", plan.get(), {},
+                                    SourceNodeOptions(basic_data.schema,
+                                                      basic_data.gen(parallel, slow))));
+      ASSERT_OK(MakeExecNode("table_sink", plan.get(), {source}, options));
+      ASSERT_OK(plan->StartProducing());
+      // Source should finish fairly quickly
+      ASSERT_FINISHES_OK(source->finished());
+      // `out` is only assigned by the consumer's Finish(), so wait for the whole
+      // plan before inspecting it instead of sleeping and hoping it has run.
+      ASSERT_FINISHES_OK(plan->finished());
+      ASSERT_OK_AND_ASSIGN(auto actual,
+                           TableFromExecBatches(basic_data.schema, basic_data.batches));
+      ASSERT_EQ(5, out->num_rows());
+      AssertTablesEqual(*actual, *out);
+    }
+  }
+}
+
 TEST(ExecPlanExecution, ConsumingSinkError) {
   struct ConsumeErrorConsumer : public SinkNodeConsumer {
     Status Consume(ExecBatch batch) override { return Status::Invalid("XYZ"); }
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index f218cef9b5e..c9808c0a3e3 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -232,6 +232,57 @@ class ConsumingSinkNode : public ExecNode {
   std::shared_ptr<SinkNodeConsumer> consumer_;
 };
 
+/// \brief Consumer that accumulates the batches it receives and, when the
+/// input is exhausted, assembles them into a single table.
+///
+/// It backs the "table_sink" factory, which wraps it in a "consuming_sink"
+/// node so the output of an execution plan can be retrieved as a table.
+struct TableSinkNodeConsumer : public arrow::compute::SinkNodeConsumer {
+ public:
+  TableSinkNodeConsumer(std::shared_ptr<Table>* out,
+                        std::shared_ptr<Schema> output_schema, MemoryPool* pool)
+      : out_(out), output_schema_(std::move(output_schema)), pool_(pool) {}
+
+  /// Convert the batch to a record batch and append it to the accumulated list.
+  Status Consume(ExecBatch batch) override {
+    // The mutex serializes appends to batches_ across concurrent callers
+    std::lock_guard<std::mutex> guard(consume_mutex_);
+    ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(output_schema_, pool_));
+    batches_.push_back(std::move(rb));
+    return Status::OK();
+  }
+
+  /// Build the table from the accumulated batches and store it in *out_.
+  Future<> Finish() override {
+    // Pass the schema explicitly: without it, building a table from zero
+    // batches is an error, so a plan producing no output could not finish OK.
+    ARROW_ASSIGN_OR_RAISE(*out_, Table::FromRecordBatches(output_schema_, batches_));
+    return Status::OK();
+  }
+
+ private:
+  std::shared_ptr<Table>* out_;  // not owned; written by Finish()
+  std::shared_ptr<Schema> output_schema_;
+  MemoryPool* pool_;
+  std::vector<std::shared_ptr<RecordBatch>> batches_;  // appended under consume_mutex_
+  std::mutex consume_mutex_;
+};
+
+/// Factory registered as "table_sink": validates the inputs (one expected),
+/// builds a TableSinkNodeConsumer from the TableSinkNodeOptions and delegates
+/// node creation to the "consuming_sink" factory.
+static Result<ExecNode*> MakeTableConsumingSinkNode(
+    compute::ExecPlan* plan, std::vector<compute::ExecNode*> inputs,
+    const compute::ExecNodeOptions& options) {
+  RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "TableConsumingSinkNode"));
+  const auto& sink_options = checked_cast<const TableSinkNodeOptions&>(options);
+  MemoryPool* pool = plan->exec_context()->memory_pool();
+  auto tb_consumer = std::make_shared<TableSinkNodeConsumer>(
+      sink_options.output_table, sink_options.output_schema, pool);
+  auto consuming_sink_node_options = ConsumingSinkNodeOptions{tb_consumer};
+  return MakeExecNode("consuming_sink", plan, inputs, consuming_sink_node_options);
+}
+
 // A sink node that accumulates inputs, then sorts them before emitting them.
 struct OrderBySinkNode final : public SinkNode {
   OrderBySinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
@@ -333,6 +384,7 @@ void RegisterSinkNode(ExecFactoryRegistry* registry) {
   DCHECK_OK(registry->AddFactory("order_by_sink", OrderBySinkNode::MakeSort));
   DCHECK_OK(registry->AddFactory("consuming_sink", ConsumingSinkNode::Make));
   DCHECK_OK(registry->AddFactory("sink", SinkNode::Make));
+  DCHECK_OK(registry->AddFactory("table_sink", MakeTableConsumingSinkNode));
 }
 
 }  // namespace internal