diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index c20dc0d048c..be2f23ad24b 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -316,7 +316,7 @@ class ARROW_EXPORT MapNode : public ExecNode {
 protected:
  void SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn, ExecBatch batch);

-  void Finish(Status finish_st = Status::OK());
+  virtual void Finish(Status finish_st = Status::OK());

 protected:
  // Counter for the number of batches received
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index 56573f61d7c..bd6c3b79b82 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -363,7 +363,7 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl {
  }

 protected:
-  virtual void Finish(const Status& finish_st) {
+  void Finish(const Status& finish_st) {
    consumer_->Finish().AddCallback([this, finish_st](const Status& st) {
      // Prefer the plan error over the consumer error
      Status final_status = finish_st & st;
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 7e72497186e..ec2b91050df 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -231,7 +231,7 @@ struct TableSourceNode : public SourceNode {
  static arrow::Status ValidateTableSourceNodeInput(const std::shared_ptr<Table> table,
                                                    const int64_t batch_size) {
    if (table == nullptr) {
-      return Status::Invalid("TableSourceNode node requires table which is not null");
+      return Status::Invalid("TableSourceNode requires table which is not null");
    }

    if (batch_size <= 0) {
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc
index 41eb401ced6..3f5d094774c 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -165,6 +165,12 @@ ExecBatch ExecBatchFromJSON(const
std::vector<ValueDescr>& descrs,
  return batch;
 }

+Future<> StartAndFinish(ExecPlan* plan) {
+  RETURN_NOT_OK(plan->Validate());
+  RETURN_NOT_OK(plan->StartProducing());
+  return plan->finished();
+}
+
 Future<std::vector<ExecBatch>> StartAndCollect(
     ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen) {
  RETURN_NOT_OK(plan->Validate());
diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h
index 9347d1343f1..9cb615ac450 100644
--- a/cpp/src/arrow/compute/exec/test_util.h
+++ b/cpp/src/arrow/compute/exec/test_util.h
@@ -82,6 +82,9 @@ struct BatchesWithSchema {
  }
 };

+ARROW_TESTING_EXPORT
+Future<> StartAndFinish(ExecPlan* plan);
+
 ARROW_TESTING_EXPORT
 Future<std::vector<ExecBatch>> StartAndCollect(
     ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen);
diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc
index ef56e6128a3..f6ac70ad45a 100644
--- a/cpp/src/arrow/compute/exec/util.cc
+++ b/cpp/src/arrow/compute/exec/util.cc
@@ -287,7 +287,7 @@ namespace compute {
 Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
                              int expected_num_inputs, const char* kind_name) {
  if (static_cast<int>(inputs.size()) != expected_num_inputs) {
-    return Status::Invalid(kind_name, " node requires ", expected_num_inputs,
+    return Status::Invalid(kind_name, " requires ", expected_num_inputs,
                           " inputs but got ", inputs.size());
  }
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index 822fc714623..10277810575 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -460,6 +460,16 @@ class TeeNode : public compute::MapNode {
  const char* kind_name() const override { return "TeeNode"; }

+  void Finish(Status finish_st) override {
+    dataset_writer_->Finish().AddCallback([this, finish_st](const Status& dw_status) {
+      // Need to wait for the task group to complete regardless of dw_status
+      task_group_.End().AddCallback(
+          [this, dw_status, finish_st](const Status& tg_status) {
+            finished_.MarkFinished(dw_status & finish_st & tg_status);
}); + }); + } + Result DoTee(const compute::ExecBatch& batch) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr record_batch, batch.ToRecordBatch(output_schema())); diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index 226c23ef5e4..4dfc6bc584d 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -18,14 +18,17 @@ #include #include #include +#include #include #include #include #include "arrow/array/array_primitive.h" +#include "arrow/compute/exec/test_util.h" #include "arrow/dataset/api.h" #include "arrow/dataset/partition.h" +#include "arrow/dataset/plan.h" #include "arrow/dataset/test_util.h" #include "arrow/filesystem/path_util.h" #include "arrow/filesystem/test_util.h" @@ -34,6 +37,8 @@ #include "arrow/testing/gtest_util.h" #include "arrow/util/io_util.h" +namespace cp = arrow::compute; + namespace arrow { using internal::TemporaryDir; @@ -342,5 +347,107 @@ TEST_F(TestFileSystemDataset, WriteProjected) { } } } + +class FileSystemWriteTest : public testing::TestWithParam> { + using PlanFactory = std::function( + const FileSystemDatasetWriteOptions&, + std::function>()>*)>; + + protected: + bool IsParallel() { return std::get<0>(GetParam()); } + bool IsSlow() { return std::get<1>(GetParam()); } + + FileSystemWriteTest() { dataset::internal::Initialize(); } + + void TestDatasetWriteRoundTrip(PlanFactory plan_factory, bool has_output) { + // Runs in-memory data through the plan and then scans out the written + // data to ensure it matches the source data + auto format = std::make_shared(); + auto fs = std::make_shared(fs::kNoTime); + FileSystemDatasetWriteOptions write_options; + write_options.file_write_options = format->DefaultWriteOptions(); + write_options.filesystem = fs; + write_options.base_dir = "root"; + write_options.partitioning = std::make_shared(schema({})); + write_options.basename_template = "{i}.feather"; + const std::string kExpectedFilename = "root/0.feather"; + + 
+    cp::BatchesWithSchema source_data;
+    source_data.batches = {
+        cp::ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"),
+        cp::ExecBatchFromJSON({int32(), boolean()},
+                              "[[5, null], [6, false], [7, false]]")};
+    source_data.schema = schema({field("i32", int32()), field("bool", boolean())});
+
+    AsyncGenerator<util::optional<cp::ExecBatch>> sink_gen;
+
+    ASSERT_OK_AND_ASSIGN(auto plan, cp::ExecPlan::Make());
+    auto source_decl = cp::Declaration::Sequence(
+        {{"source", cp::SourceNodeOptions{source_data.schema,
+                                          source_data.gen(IsParallel(), IsSlow())}}});
+    auto declarations = plan_factory(write_options, &sink_gen);
+    declarations.insert(declarations.begin(), std::move(source_decl));
+    ASSERT_OK(cp::Declaration::Sequence(std::move(declarations)).AddToPlan(plan.get()));
+
+    if (has_output) {
+      ASSERT_FINISHES_OK_AND_ASSIGN(auto out_batches,
+                                    cp::StartAndCollect(plan.get(), sink_gen));
+      cp::AssertExecBatchesEqual(source_data.schema, source_data.batches, out_batches);
+    } else {
+      ASSERT_FINISHES_OK(cp::StartAndFinish(plan.get()));
+    }
+
+    // Read written dataset and make sure it matches
+    ASSERT_OK_AND_ASSIGN(auto dataset_factory, FileSystemDatasetFactory::Make(
+                                                   fs, {kExpectedFilename}, format, {}));
+    ASSERT_OK_AND_ASSIGN(auto written_dataset, dataset_factory->Finish(FinishOptions{}));
+    AssertSchemaEqual(*source_data.schema, *written_dataset->schema());
+
+    ASSERT_OK_AND_ASSIGN(plan, cp::ExecPlan::Make());
+    ASSERT_OK_AND_ASSIGN(auto scanner_builder, written_dataset->NewScan());
+    ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
+    ASSERT_OK(cp::Declaration::Sequence(
+                  {
+                      {"scan", ScanNodeOptions{written_dataset, scanner->options()}},
+                      {"sink", cp::SinkNodeOptions{&sink_gen}},
+                  })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto written_batches,
+                                  cp::StartAndCollect(plan.get(), sink_gen));
+    cp::AssertExecBatchesEqual(source_data.schema, source_data.batches, written_batches);
+  }
+};
+
+TEST_P(FileSystemWriteTest, Write) {
+  auto
plan_factory =
+      [](const FileSystemDatasetWriteOptions& write_options,
+         std::function<Future<util::optional<cp::ExecBatch>>()>* sink_gen) {
+        return std::vector<cp::Declaration>{{"write", WriteNodeOptions{write_options}}};
+      };
+  TestDatasetWriteRoundTrip(plan_factory, /*has_output=*/false);
+}
+
+TEST_P(FileSystemWriteTest, TeeWrite) {
+  auto plan_factory =
+      [](const FileSystemDatasetWriteOptions& write_options,
+         std::function<Future<util::optional<cp::ExecBatch>>()>* sink_gen) {
+        return std::vector<cp::Declaration>{
+            {"tee", WriteNodeOptions{write_options}},
+            {"sink", cp::SinkNodeOptions{sink_gen}},
+        };
+      };
+  TestDatasetWriteRoundTrip(plan_factory, /*has_output=*/true);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    FileSystemWrite, FileSystemWriteTest,
+    testing::Combine(testing::Values(false, true), testing::Values(false, true)),
+    [](const testing::TestParamInfo<std::tuple<bool, bool>>& info) {
+      std::string parallel_desc = std::get<0>(info.param) ? "parallel" : "serial";
+      std::string speed_desc = std::get<1>(info.param) ? "slow" : "fast";
+      return parallel_desc + "_" + speed_desc;
+    });
+
 }  // namespace dataset
 }  // namespace arrow