Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions cpp/src/arrow/compute/exec/plan_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,71 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) {
}
}

TEST(ExecPlanExecution, SinkNodeOutputSchema) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;

auto basic_data = MakeBasicBatches();

ASSERT_OK(
Declaration::Sequence({
{"source", SourceNodeOptions{basic_data.schema,
basic_data.gen(true, true)}},
{"sink", SinkNodeOptions{&sink_gen}},
})
.AddToPlan(plan.get()));
ASSERT_EQ(plan->sources()[0]->output_schema(), plan->sinks()[0]->output_schema());
}

TEST(ExecPlanExecution, ConsumingSinkNodeOutputSchema) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
std::atomic<uint32_t> batches_seen{0};
Future<> finish = Future<>::Make();
struct TestConsumer : public SinkNodeConsumer {
TestConsumer(std::atomic<uint32_t>* batches_seen, Future<> finish)
: batches_seen(batches_seen), finish(std::move(finish)) {}

Status Consume(ExecBatch batch) override {
(*batches_seen)++;
return Status::OK();
}

Future<> Finish() override { return finish; }

std::atomic<uint32_t>* batches_seen;
Future<> finish;
};
std::shared_ptr<TestConsumer> consumer =
std::make_shared<TestConsumer>(&batches_seen, finish);

auto basic_data = MakeBasicBatches();
ASSERT_OK_AND_ASSIGN(
auto source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions(basic_data.schema, basic_data.gen(true, true))));
ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source},
ConsumingSinkNodeOptions(consumer)));
ASSERT_EQ(plan->sources()[0]->output_schema(), plan->sinks()[0]->output_schema());
}

TEST(ExecPlanExecution, OrderBySinkNodeOutputSchema) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
auto input_schema = schema({field("a", int32()), field("b", boolean())});
AsyncGenerator<util::optional<ExecBatch>> sink_gen;

auto random_data = MakeRandomBatches(input_schema, 10);

SortOptions options({SortKey("a", SortOrder::Ascending)});
ASSERT_OK(Declaration::Sequence(
{
{"source",
SourceNodeOptions{random_data.schema, random_data.gen(true, true)}},
{"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
})
.AddToPlan(plan.get()));
ASSERT_EQ(plan->sources()[0]->output_schema(), plan->sinks()[0]->output_schema());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to just add these checks to existing tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we can. I separated it in case it diverts from the purpose. There is code duplication with the current approach.

}

TEST(ExecPlanExecution, ConsumingSinkError) {
struct ConsumeErrorConsumer : public SinkNodeConsumer {
Status Consume(ExecBatch batch) override { return Status::Invalid("XYZ"); }
Expand Down
44 changes: 23 additions & 21 deletions cpp/src/arrow/compute/exec/sink_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,20 @@ namespace {
class SinkNode : public ExecNode {
public:
SinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema,
AsyncGenerator<util::optional<ExecBatch>>* generator,
util::BackpressureOptions backpressure)
: ExecNode(plan, std::move(inputs), {"collected"}, {},
: ExecNode(plan, std::move(inputs), {"collected"}, std::move(output_schema),
/*num_outputs=*/0),
producer_(MakeProducer(generator, std::move(backpressure))) {}

static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode"));

auto schema = inputs[0]->output_schema();
const auto& sink_options = checked_cast<const SinkNodeOptions&>(options);
return plan->EmplaceNode<SinkNode>(plan, std::move(inputs), sink_options.generator,
sink_options.backpressure);
return plan->EmplaceNode<SinkNode>(plan, std::move(inputs), std::move(schema),
sink_options.generator, sink_options.backpressure);
}

static PushGenerator<util::optional<ExecBatch>>::Producer MakeProducer(
Expand Down Expand Up @@ -157,18 +158,19 @@ class SinkNode : public ExecNode {
class ConsumingSinkNode : public ExecNode {
public:
ConsumingSinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema,
std::shared_ptr<SinkNodeConsumer> consumer)
: ExecNode(plan, std::move(inputs), {"to_consume"}, {},
: ExecNode(plan, std::move(inputs), {"to_consume"}, std::move(output_schema),
/*num_outputs=*/0),
consumer_(std::move(consumer)) {}

static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode"));

auto schema = inputs[0]->output_schema();
const auto& sink_options = checked_cast<const ConsumingSinkNodeOptions&>(options);
return plan->EmplaceNode<ConsumingSinkNode>(plan, std::move(inputs),
std::move(sink_options.consumer));
return plan->EmplaceNode<ConsumingSinkNode>(
plan, std::move(inputs), std::move(schema), std::move(sink_options.consumer));
}

const char* kind_name() const override { return "ConsumingSinkNode"; }
Expand Down Expand Up @@ -307,10 +309,12 @@ static Result<ExecNode*> MakeTableConsumingSinkNode(
// A sink node that accumulates inputs, then sorts them before emitting them.
struct OrderBySinkNode final : public SinkNode {
OrderBySinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema,
std::unique_ptr<OrderByImpl> impl,
AsyncGenerator<util::optional<ExecBatch>>* generator,
util::BackpressureOptions backpressure)
: SinkNode(plan, std::move(inputs), generator, std::move(backpressure)),
: SinkNode(plan, std::move(inputs), std::move(output_schema), generator,
std::move(backpressure)),
impl_{std::move(impl)} {}

const char* kind_name() const override { return "OrderBySinkNode"; }
Expand All @@ -319,29 +323,27 @@ struct OrderBySinkNode final : public SinkNode {
static Result<ExecNode*> MakeSort(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode"));

auto schema = inputs[0]->output_schema();
const auto& sink_options = checked_cast<const OrderBySinkNodeOptions&>(options);
ARROW_ASSIGN_OR_RAISE(
std::unique_ptr<OrderByImpl> impl,
OrderByImpl::MakeSort(plan->exec_context(), inputs[0]->output_schema(),
sink_options.sort_options));
return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(impl),
sink_options.generator,
OrderByImpl::MakeSort(plan->exec_context(), schema, sink_options.sort_options));
return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(schema),
std::move(impl), sink_options.generator,
sink_options.backpressure);
}

// A sink node that receives inputs and then compute top_k/bottom_k.
static Result<ExecNode*> MakeSelectK(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode"));

auto schema = inputs[0]->output_schema();
const auto& sink_options = checked_cast<const SelectKSinkNodeOptions&>(options);
ARROW_ASSIGN_OR_RAISE(
std::unique_ptr<OrderByImpl> impl,
OrderByImpl::MakeSelectK(plan->exec_context(), inputs[0]->output_schema(),
sink_options.select_k_options));
return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(impl),
sink_options.generator,
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<OrderByImpl> impl,
OrderByImpl::MakeSelectK(plan->exec_context(), schema,
sink_options.select_k_options));
return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(schema),
std::move(impl), sink_options.generator,
sink_options.backpressure);
}

Expand Down