diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index e176c701b65..04225e31343 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -559,6 +559,71 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) { } } +TEST(ExecPlanExecution, SinkNodeOutputSchema) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + AsyncGenerator> sink_gen; + + auto basic_data = MakeBasicBatches(); + + ASSERT_OK( + Declaration::Sequence({ + {"source", SourceNodeOptions{basic_data.schema, + basic_data.gen(true, true)}}, + {"sink", SinkNodeOptions{&sink_gen}}, + }) + .AddToPlan(plan.get())); + ASSERT_EQ(plan->sources()[0]->output_schema(), plan->sinks()[0]->output_schema()); +} + +TEST(ExecPlanExecution, ConsumingSinkNodeOutputSchema) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + std::atomic batches_seen{0}; + Future<> finish = Future<>::Make(); + struct TestConsumer : public SinkNodeConsumer { + TestConsumer(std::atomic* batches_seen, Future<> finish) + : batches_seen(batches_seen), finish(std::move(finish)) {} + + Status Consume(ExecBatch batch) override { + (*batches_seen)++; + return Status::OK(); + } + + Future<> Finish() override { return finish; } + + std::atomic* batches_seen; + Future<> finish; + }; + std::shared_ptr consumer = + std::make_shared(&batches_seen, finish); + + auto basic_data = MakeBasicBatches(); + ASSERT_OK_AND_ASSIGN( + auto source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions(basic_data.schema, basic_data.gen(true, true)))); + ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source}, + ConsumingSinkNodeOptions(consumer))); + ASSERT_EQ(plan->sources()[0]->output_schema(), plan->sinks()[0]->output_schema()); +} + +TEST(ExecPlanExecution, OrderBySinkNodeOutputSchema) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto input_schema = schema({field("a", int32()), field("b", boolean())}); + AsyncGenerator> sink_gen; + + auto random_data = MakeRandomBatches(input_schema, 10); + + SortOptions options({SortKey("a", SortOrder::Ascending)}); + ASSERT_OK(Declaration::Sequence( + { + {"source", + SourceNodeOptions{random_data.schema, random_data.gen(true, true)}}, + {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}}, + }) + .AddToPlan(plan.get())); + ASSERT_EQ(plan->sources()[0]->output_schema(), plan->sinks()[0]->output_schema()); +} + TEST(ExecPlanExecution, ConsumingSinkError) { struct ConsumeErrorConsumer : public SinkNodeConsumer { Status Consume(ExecBatch batch) override { return Status::Invalid("XYZ"); } diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index 13564c736b5..3a31288e6b6 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -49,19 +49,20 @@ namespace { class SinkNode : public ExecNode { public: SinkNode(ExecPlan* plan, std::vector inputs, + std::shared_ptr output_schema, AsyncGenerator>* generator, util::BackpressureOptions backpressure) - : ExecNode(plan, std::move(inputs), {"collected"}, {}, + : ExecNode(plan, std::move(inputs), {"collected"}, std::move(output_schema), /*num_outputs=*/0), producer_(MakeProducer(generator, std::move(backpressure))) {} static Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode")); - + auto schema = inputs[0]->output_schema(); const auto& sink_options = checked_cast(options); - return plan->EmplaceNode(plan, std::move(inputs), sink_options.generator, - sink_options.backpressure); + return plan->EmplaceNode(plan, std::move(inputs), std::move(schema), + sink_options.generator, sink_options.backpressure); } static PushGenerator>::Producer MakeProducer( @@ -157,18 +158,19 @@ class SinkNode : public ExecNode { class ConsumingSinkNode : public ExecNode { public: ConsumingSinkNode(ExecPlan* plan, std::vector inputs, + std::shared_ptr output_schema, std::shared_ptr consumer) - : ExecNode(plan, std::move(inputs), {"to_consume"}, {}, + : ExecNode(plan, std::move(inputs), {"to_consume"}, std::move(output_schema), /*num_outputs=*/0), consumer_(std::move(consumer)) {} static Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode")); - + auto schema = inputs[0]->output_schema(); const auto& sink_options = checked_cast(options); - return plan->EmplaceNode(plan, std::move(inputs), - std::move(sink_options.consumer)); + return plan->EmplaceNode( + plan, std::move(inputs), std::move(schema), std::move(sink_options.consumer)); } const char* kind_name() const override { return "ConsumingSinkNode"; } @@ -307,10 +309,12 @@ static Result MakeTableConsumingSinkNode( // A sink node that accumulates inputs, then sorts them before emitting them. struct OrderBySinkNode final : public SinkNode { OrderBySinkNode(ExecPlan* plan, std::vector inputs, + std::shared_ptr output_schema, std::unique_ptr impl, AsyncGenerator>* generator, util::BackpressureOptions backpressure) - : SinkNode(plan, std::move(inputs), generator, std::move(backpressure)), + : SinkNode(plan, std::move(inputs), std::move(output_schema), generator, + std::move(backpressure)), impl_{std::move(impl)} {} const char* kind_name() const override { return "OrderBySinkNode"; } @@ -319,14 +323,13 @@ struct OrderBySinkNode final : public SinkNode { static Result MakeSort(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode")); - + auto schema = inputs[0]->output_schema(); const auto& sink_options = checked_cast(options); ARROW_ASSIGN_OR_RAISE( std::unique_ptr impl, - OrderByImpl::MakeSort(plan->exec_context(), inputs[0]->output_schema(), - sink_options.sort_options)); - return plan->EmplaceNode(plan, std::move(inputs), std::move(impl), - sink_options.generator, + OrderByImpl::MakeSort(plan->exec_context(), schema, sink_options.sort_options)); + return plan->EmplaceNode(plan, std::move(inputs), std::move(schema), + std::move(impl), sink_options.generator, sink_options.backpressure); } @@ -334,14 +337,13 @@ struct OrderBySinkNode final : public SinkNode { static Result MakeSelectK(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode")); - + auto schema = inputs[0]->output_schema(); const auto& sink_options = checked_cast(options); - ARROW_ASSIGN_OR_RAISE( - std::unique_ptr impl, - OrderByImpl::MakeSelectK(plan->exec_context(), inputs[0]->output_schema(), - sink_options.select_k_options)); - return plan->EmplaceNode(plan, std::move(inputs), std::move(impl), - sink_options.generator, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, + OrderByImpl::MakeSelectK(plan->exec_context(), schema, + sink_options.select_k_options)); + return plan->EmplaceNode(plan, std::move(inputs), std::move(schema), + std::move(impl), sink_options.generator, sink_options.backpressure); }