diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 5c64cf2fc30..1de81cd41dd 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -101,7 +101,7 @@ struct ExecPlanImpl : public ExecPlan {
       futures.push_back(node->finished());
     }
 
-    finished_ = AllComplete(std::move(futures));
+    finished_ = AllFinished(futures);
     return st;
   }
 
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index d886eb92e2c..cde118acd8a 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -112,6 +112,26 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
   std::function<Future<util::optional<ExecBatch>>()>* generator;
 };
 
+class ARROW_EXPORT SinkNodeConsumer {
+ public:
+  virtual ~SinkNodeConsumer() = default;
+  /// \brief Consume a batch of data
+  virtual Status Consume(ExecBatch batch) = 0;
+  /// \brief Signal to the consumer that the last batch has been delivered
+  ///
+  /// The returned future should only finish when all outstanding tasks have completed
+  virtual Future<> Finish() = 0;
+};
+
+/// \brief Add a sink node which consumes data within the exec plan run
+class ARROW_EXPORT ConsumingSinkNodeOptions : public ExecNodeOptions {
+ public:
+  explicit ConsumingSinkNodeOptions(std::shared_ptr<SinkNodeConsumer> consumer)
+      : consumer(std::move(consumer)) {}
+
+  std::shared_ptr<SinkNodeConsumer> consumer;
+};
+
 /// \brief Make a node which sorts rows passed through it
 ///
 /// All batches pushed to this node will be accumulated, then sorted, by the given
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 85b657fe118..0d20050dd44 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -374,6 +374,108 @@ TEST(ExecPlanExecution, SourceSinkError) {
               Finishes(Raises(StatusCode::Invalid, HasSubstr("Artificial"))));
 }
 
+TEST(ExecPlanExecution, SourceConsumingSink) {
+  for (bool slow : {false, true}) {
+    SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+    for (bool parallel : {false, true}) {
+      SCOPED_TRACE(parallel ?
"parallel" : "single threaded"); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + std::atomic batches_seen{0}; + Future<> finish = Future<>::Make(); + struct TestConsumer : public SinkNodeConsumer { + TestConsumer(std::atomic* batches_seen, Future<> finish) + : batches_seen(batches_seen), finish(std::move(finish)) {} + + Status Consume(ExecBatch batch) override { + (*batches_seen)++; + return Status::OK(); + } + + Future<> Finish() override { return finish; } + + std::atomic* batches_seen; + Future<> finish; + }; + std::shared_ptr consumer = + std::make_shared(&batches_seen, finish); + + auto basic_data = MakeBasicBatches(); + ASSERT_OK_AND_ASSIGN( + auto source, MakeExecNode("source", plan.get(), {}, + SourceNodeOptions(basic_data.schema, + basic_data.gen(parallel, slow)))); + ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source}, + ConsumingSinkNodeOptions(consumer))); + ASSERT_OK(plan->StartProducing()); + // Source should finish fairly quickly + ASSERT_FINISHES_OK(source->finished()); + SleepABit(); + ASSERT_EQ(2, batches_seen); + // Consumer isn't finished and so plan shouldn't have finished + AssertNotFinished(plan->finished()); + // Mark consumption complete, plan should finish + finish.MarkFinished(); + ASSERT_FINISHES_OK(plan->finished()); + } + } +} + +TEST(ExecPlanExecution, ConsumingSinkError) { + struct ConsumeErrorConsumer : public SinkNodeConsumer { + Status Consume(ExecBatch batch) override { return Status::Invalid("XYZ"); } + Future<> Finish() override { return Future<>::MakeFinished(); } + }; + struct FinishErrorConsumer : public SinkNodeConsumer { + Status Consume(ExecBatch batch) override { return Status::OK(); } + Future<> Finish() override { return Future<>::MakeFinished(Status::Invalid("XYZ")); } + }; + std::vector> consumers{ + std::make_shared(), std::make_shared()}; + + for (auto& consumer : consumers) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto basic_data = MakeBasicBatches(); + ASSERT_OK(Declaration::Sequence( + {{"source", + SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))}, + {"consuming_sink", ConsumingSinkNodeOptions(consumer)}}) + .AddToPlan(plan.get())); + ASSERT_OK_AND_ASSIGN( + auto source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions(basic_data.schema, basic_data.gen(false, false)))); + ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source}, + ConsumingSinkNodeOptions(consumer))); + ASSERT_OK(plan->StartProducing()); + ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished()); + } +} + +TEST(ExecPlanExecution, ConsumingSinkErrorFinish) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + struct FinishErrorConsumer : public SinkNodeConsumer { + Status Consume(ExecBatch batch) override { return Status::OK(); } + Future<> Finish() override { return Future<>::MakeFinished(Status::Invalid("XYZ")); } + }; + std::shared_ptr consumer = std::make_shared(); + + auto basic_data = MakeBasicBatches(); + ASSERT_OK( + Declaration::Sequence( + {{"source", SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))}, + {"consuming_sink", ConsumingSinkNodeOptions(consumer)}}) + .AddToPlan(plan.get())); + ASSERT_OK_AND_ASSIGN( + auto source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions(basic_data.schema, basic_data.gen(false, false)))); + ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source}, + ConsumingSinkNodeOptions(consumer))); + ASSERT_OK(plan->StartProducing()); + ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished()); +} + TEST(ExecPlanExecution, StressSourceSink) 
{ for (bool slow : {false, true}) { SCOPED_TRACE(slow ? "slowed" : "unslowed"); diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index 76889c244d7..d84f3c44115 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -31,6 +31,7 @@ #include "arrow/result.h" #include "arrow/table.h" #include "arrow/util/async_generator.h" +#include "arrow/util/async_util.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" #include "arrow/util/logging.h" @@ -132,6 +133,104 @@ class SinkNode : public ExecNode { PushGenerator>::Producer producer_; }; +// A sink node that owns consuming the data and will not finish until the consumption +// is finished. Use SinkNode if you are transferring the ownership of the data to another +// system. Use ConsumingSinkNode if the data is being consumed within the exec plan (i.e. +// the exec plan should not complete until the consumption has completed). +class ConsumingSinkNode : public ExecNode { + public: + ConsumingSinkNode(ExecPlan* plan, std::vector inputs, + std::shared_ptr consumer) + : ExecNode(plan, std::move(inputs), {"to_consume"}, {}, + /*num_outputs=*/0), + consumer_(std::move(consumer)) {} + + static Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode")); + + const auto& sink_options = checked_cast(options); + return plan->EmplaceNode(plan, std::move(inputs), + std::move(sink_options.consumer)); + } + + const char* kind_name() const override { return "ConsumingSinkNode"; } + + Status StartProducing() override { + finished_ = Future<>::Make(); + return Status::OK(); + } + + // sink nodes have no outputs from which to feel backpressure + [[noreturn]] static void NoOutputs() { + Unreachable("no outputs; this should never be called"); + } + [[noreturn]] void ResumeProducing(ExecNode* output) override { NoOutputs(); } + [[noreturn]] void PauseProducing(ExecNode* output) override { NoOutputs(); } + [[noreturn]] void StopProducing(ExecNode* output) override { NoOutputs(); } + + void StopProducing() override { + Finish(Status::Invalid("ExecPlan was stopped early")); + inputs_[0]->StopProducing(this); + } + + Future<> finished() override { return finished_; } + + void InputReceived(ExecNode* input, ExecBatch batch) override { + DCHECK_EQ(input, inputs_[0]); + + // This can happen if an error was received and the source hasn't yet stopped. 
Since + // we have already called consumer_->Finish we don't want to call consumer_->Consume + if (input_counter_.Completed()) { + return; + } + + Status consumption_status = consumer_->Consume(std::move(batch)); + if (!consumption_status.ok()) { + if (input_counter_.Cancel()) { + Finish(std::move(consumption_status)); + } + inputs_[0]->StopProducing(this); + return; + } + + if (input_counter_.Increment()) { + Finish(Status::OK()); + } + } + + void ErrorReceived(ExecNode* input, Status error) override { + DCHECK_EQ(input, inputs_[0]); + + if (input_counter_.Cancel()) { + Finish(std::move(error)); + } + + inputs_[0]->StopProducing(this); + } + + void InputFinished(ExecNode* input, int total_batches) override { + if (input_counter_.SetTotal(total_batches)) { + Finish(Status::OK()); + } + } + + protected: + virtual void Finish(const Status& finish_st) { + consumer_->Finish().AddCallback([this, finish_st](const Status& st) { + // Prefer the plan error over the consumer error + Status final_status = finish_st & st; + finished_.MarkFinished(std::move(final_status)); + }); + } + + AtomicCounter input_counter_; + + Future<> finished_ = Future<>::MakeFinished(); + std::shared_ptr consumer_; +}; + +// A sink node that accumulates inputs, then sorts them before emitting them. struct OrderBySinkNode final : public SinkNode { OrderBySinkNode(ExecPlan* plan, std::vector inputs, std::unique_ptr impl, @@ -226,6 +325,7 @@ namespace internal { void RegisterSinkNode(ExecFactoryRegistry* registry) { DCHECK_OK(registry->AddFactory("select_k_sink", OrderBySinkNode::MakeSelectK)); DCHECK_OK(registry->AddFactory("order_by_sink", OrderBySinkNode::MakeSort)); + DCHECK_OK(registry->AddFactory("consuming_sink", ConsumingSinkNode::Make)); DCHECK_OK(registry->AddFactory("sink", SinkNode::Make)); } diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index f64b009f39f..127a1b4f9b3 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -79,60 +79,59 @@ struct SourceNode : ExecNode { options.executor = executor; options.should_schedule = ShouldSchedule::IfDifferentExecutor; } - finished_ = Loop([this, executor, options] { - std::unique_lock lock(mutex_); - int total_batches = batch_count_++; - if (stop_requested_) { - return Future>::MakeFinished(Break(total_batches)); + finished_ = + Loop([this, executor, options] { + std::unique_lock lock(mutex_); + int total_batches = batch_count_++; + if (stop_requested_) { + return Future>::MakeFinished(Break(total_batches)); + } + lock.unlock(); + + return generator_().Then( + [=](const util::optional& maybe_batch) -> ControlFlow { + std::unique_lock lock(mutex_); + if (IsIterationEnd(maybe_batch) || stop_requested_) { + stop_requested_ = true; + return Break(total_batches); + } + lock.unlock(); + ExecBatch batch = std::move(*maybe_batch); + + if (executor) { + auto status = + task_group_.AddTask([this, executor, batch]() -> Result> { + return executor->Submit([=]() { + outputs_[0]->InputReceived(this, std::move(batch)); + return Status::OK(); + }); + }); + if (!status.ok()) { + outputs_[0]->ErrorReceived(this, std::move(status)); + return Break(total_batches); } - lock.unlock(); - - return generator_().Then( - [=](const util::optional& batch) -> ControlFlow { - std::unique_lock lock(mutex_); - if (IsIterationEnd(batch) || stop_requested_) { - stop_requested_ = true; - return Break(total_batches); - } - lock.unlock(); - - if (executor) { - auto maybe_future = executor->Submit([=]() { - 
outputs_[0]->InputReceived(this, *batch); - return Status::OK(); - }); - if (!maybe_future.ok()) { - outputs_[0]->ErrorReceived(this, maybe_future.status()); - return Break(total_batches); - } - auto status = - task_group_.AddTask(maybe_future.MoveValueUnsafe()); - if (!status.ok()) { - outputs_[0]->ErrorReceived(this, std::move(status)); - return Break(total_batches); - } - } else { - outputs_[0]->InputReceived(this, *batch); - } - return Continue(); - }, - [=](const Status& error) -> ControlFlow { - // NB: ErrorReceived is independent of InputFinished, but - // ErrorReceived will usually prompt StopProducing which will - // prompt InputFinished. ErrorReceived may still be called from a - // node which was requested to stop (indeed, the request to stop - // may prompt an error). - std::unique_lock lock(mutex_); - stop_requested_ = true; - lock.unlock(); - outputs_[0]->ErrorReceived(this, error); - return Break(total_batches); - }, - options); - }).Then([&](int total_batches) { - outputs_[0]->InputFinished(this, total_batches); - return task_group_.WaitForTasksToFinish(); - }); + } else { + outputs_[0]->InputReceived(this, std::move(batch)); + } + return Continue(); + }, + [=](const Status& error) -> ControlFlow { + // NB: ErrorReceived is independent of InputFinished, but + // ErrorReceived will usually prompt StopProducing which will + // prompt InputFinished. ErrorReceived may still be called from a + // node which was requested to stop (indeed, the request to stop + // may prompt an error). + std::unique_lock lock(mutex_); + stop_requested_ = true; + lock.unlock(); + outputs_[0]->ErrorReceived(this, error); + return Break(total_batches); + }, + options); + }).Then([&](int total_batches) { + outputs_[0]->InputFinished(this, total_batches); + return task_group_.End(); + }); return Status::OK(); } diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index ed89bece6a3..10c848968f2 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -235,6 +235,9 @@ class AtomicCounter { // return true if the counter has not already been completed bool Cancel() { return DoneOnce(); } + // return true if the counter has finished or been cancelled + bool Completed() { return complete_.load(); } + private: // ensure there is only one true return from Increment(), SetTotal(), or Cancel() bool DoneOnce() { diff --git a/cpp/src/arrow/dataset/CMakeLists.txt b/cpp/src/arrow/dataset/CMakeLists.txt index 658eb0f9172..c601e9fb1e2 100644 --- a/cpp/src/arrow/dataset/CMakeLists.txt +++ b/cpp/src/arrow/dataset/CMakeLists.txt @@ -26,6 +26,7 @@ set(ARROW_DATASET_SRCS file_base.cc file_ipc.cc partition.cc + plan.cc projector.cc scanner.cc) diff --git a/cpp/src/arrow/dataset/dataset_writer.cc b/cpp/src/arrow/dataset/dataset_writer.cc index 6233b4bf4af..12b7858c4b9 100644 --- a/cpp/src/arrow/dataset/dataset_writer.cc +++ b/cpp/src/arrow/dataset/dataset_writer.cc @@ -324,7 +324,7 @@ class DatasetWriterDirectoryQueue : public util::AsyncDestroyable { Future<> DoDestroy() override { latest_open_file_.reset(); - return task_group_.WaitForTasksToFinish(); + return task_group_.End(); } private: @@ -482,7 +482,7 @@ class DatasetWriter::DatasetWriterImpl : public util::AsyncDestroyable { Future<> DoDestroy() override { directory_queues_.clear(); - return task_group_.WaitForTasksToFinish().Then([this] { return err_; }); + return task_group_.End().Then([this] { return err_; }); } util::AsyncTaskGroup task_group_; diff --git a/cpp/src/arrow/dataset/file_base.cc 
b/cpp/src/arrow/dataset/file_base.cc index ec65bf12e23..24eba5a496f 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -17,6 +17,8 @@ #include "arrow/dataset/file_base.h" +#include + #include #include #include @@ -322,72 +324,124 @@ Status FileWriter::Finish() { namespace { -Future<> WriteNextBatch(internal::DatasetWriter* dataset_writer, TaggedRecordBatch batch, - const FileSystemDatasetWriteOptions& write_options) { - ARROW_ASSIGN_OR_RAISE(auto groups, - write_options.partitioning->Partition(batch.record_batch)); - batch.record_batch.reset(); // drop to hopefully conserve memory +class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer { + public: + DatasetWritingSinkNodeConsumer(std::shared_ptr schema, + std::unique_ptr dataset_writer, + FileSystemDatasetWriteOptions write_options) + : schema(std::move(schema)), + dataset_writer(std::move(dataset_writer)), + write_options(std::move(write_options)) {} + + Status Consume(compute::ExecBatch batch) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr record_batch, + batch.ToRecordBatch(schema)); + return WriteNextBatch(std::move(record_batch), batch.guarantee); + } - if (groups.batches.size() > static_cast(write_options.max_partitions)) { - return Status::Invalid("Fragment would be written into ", groups.batches.size(), - " partitions. This exceeds the maximum of ", - write_options.max_partitions); + Future<> Finish() { + RETURN_NOT_OK(task_group.AddTask([this] { return dataset_writer->Finish(); })); + return task_group.End(); } - std::shared_ptr counter = std::make_shared(0); - std::shared_ptr fragment = std::move(batch.fragment); + private: + Status WriteNextBatch(std::shared_ptr batch, + compute::Expression guarantee) { + ARROW_ASSIGN_OR_RAISE(auto groups, write_options.partitioning->Partition(batch)); + batch.reset(); // drop to hopefully conserve memory - AsyncGenerator> partitioned_batch_gen = - [groups, counter, fragment, &write_options, - dataset_writer]() -> Future> { - auto index = *counter; - if (index >= groups.batches.size()) { - return AsyncGeneratorEnd>(); + if (groups.batches.size() > static_cast(write_options.max_partitions)) { + return Status::Invalid("Fragment would be written into ", groups.batches.size(), + " partitions. 
This exceeds the maximum of ", + write_options.max_partitions); } - auto partition_expression = - and_(groups.expressions[index], fragment->partition_expression()); - auto next_batch = groups.batches[index]; - ARROW_ASSIGN_OR_RAISE(std::string destination, - write_options.partitioning->Format(partition_expression)); - (*counter)++; - return dataset_writer->WriteRecordBatch(next_batch, destination).Then([next_batch] { - return next_batch; - }); - }; - return VisitAsyncGenerator( - std::move(partitioned_batch_gen), - [](const std::shared_ptr&) -> Status { return Status::OK(); }); -} + for (std::size_t index = 0; index < groups.batches.size(); index++) { + auto partition_expression = and_(groups.expressions[index], guarantee); + auto next_batch = groups.batches[index]; + ARROW_ASSIGN_OR_RAISE(std::string destination, + write_options.partitioning->Format(partition_expression)); + RETURN_NOT_OK(task_group.AddTask([this, next_batch, destination] { + return dataset_writer->WriteRecordBatch(next_batch, destination); + })); + } + return Status::OK(); + } + + std::shared_ptr schema; + std::unique_ptr dataset_writer; + FileSystemDatasetWriteOptions write_options; + + util::SerializedAsyncTaskGroup task_group; +}; } // namespace Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options, std::shared_ptr scanner) { - ARROW_ASSIGN_OR_RAISE(auto batch_gen, scanner->ScanBatchesAsync()); + const io::IOContext& io_context = scanner->options()->io_context; + std::shared_ptr exec_context = + std::make_shared(io_context.pool(), + ::arrow::internal::GetCpuThreadPool()); + + ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context.get())); + + auto exprs = scanner->options()->projection.call()->arguments; + auto names = checked_cast( + scanner->options()->projection.call()->options.get()) + ->field_names; + std::shared_ptr dataset = scanner->dataset(); + + RETURN_NOT_OK( + compute::Declaration::Sequence( + { + {"scan", ScanNodeOptions{dataset, scanner->options()}}, + {"filter", compute::FilterNodeOptions{scanner->options()->filter}}, + {"project", + compute::ProjectNodeOptions{std::move(exprs), std::move(names)}}, + {"write", + WriteNodeOptions{write_options, scanner->options()->projected_schema}}, + }) + .AddToPlan(plan.get())); + + RETURN_NOT_OK(plan->StartProducing()); + return plan->finished().status(); +} + +Result MakeWriteNode(compute::ExecPlan* plan, + std::vector inputs, + const compute::ExecNodeOptions& options) { + if (inputs.size() != 1) { + return Status::Invalid("Write SinkNode requires exactly 1 input, got ", + inputs.size()); + } + + const WriteNodeOptions write_node_options = + checked_cast(options); + const FileSystemDatasetWriteOptions& write_options = write_node_options.write_options; + std::shared_ptr schema = write_node_options.schema; + ARROW_ASSIGN_OR_RAISE(auto dataset_writer, internal::DatasetWriter::Make(write_options)); - AsyncGenerator> queued_batch_gen = - [batch_gen, &dataset_writer, &write_options]() -> Future> { - Future next_batch_fut = batch_gen(); - return next_batch_fut.Then( - [&dataset_writer, &write_options](const TaggedRecordBatch& batch) { - if (IsIterationEnd(batch)) { - return AsyncGeneratorEnd>(); - } - return WriteNextBatch(dataset_writer.get(), batch, write_options).Then([] { - return std::make_shared(0); - }); - }); - }; - Future<> queue_fut = - VisitAsyncGenerator(std::move(queued_batch_gen), - [&](const std::shared_ptr&) { return Status::OK(); }); + std::shared_ptr consumer = + std::make_shared( + std::move(schema), 
std::move(dataset_writer), write_options); + + ARROW_ASSIGN_OR_RAISE( + auto node, + compute::MakeExecNode("consuming_sink", plan, std::move(inputs), + compute::ConsumingSinkNodeOptions{std::move(consumer)})); - ARROW_RETURN_NOT_OK(queue_fut.status()); - return dataset_writer->Finish().status(); + return node; } +namespace internal { +void InitializeDatasetWriter(arrow::compute::ExecFactoryRegistry* registry) { + DCHECK_OK(registry->AddFactory("write", MakeWriteNode)); +} +} // namespace internal + } // namespace dataset + } // namespace arrow diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index fc5e17b0c2d..a645c2c8b08 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -407,7 +407,23 @@ struct ARROW_DS_EXPORT FileSystemDatasetWriteOptions { } }; +/// \brief Wraps FileSystemDatasetWriteOptions for consumption as compute::ExecNodeOptions +class ARROW_DS_EXPORT WriteNodeOptions : public compute::ExecNodeOptions { + public: + explicit WriteNodeOptions(FileSystemDatasetWriteOptions options, + std::shared_ptr schema) + : write_options(std::move(options)), schema(std::move(schema)) {} + + FileSystemDatasetWriteOptions write_options; + std::shared_ptr schema; +}; + /// @} +namespace internal { +ARROW_DS_EXPORT void InitializeDatasetWriter( + arrow::compute::ExecFactoryRegistry* registry); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/plan.cc b/cpp/src/arrow/dataset/plan.cc new file mode 100644 index 00000000000..9b222ff578c --- /dev/null +++ b/cpp/src/arrow/dataset/plan.cc @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/dataset/plan.h" + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/dataset/file_base.h" +#include "arrow/dataset/scanner.h" + +namespace arrow { +namespace dataset { +namespace internal { + +void Initialize() { + static auto registry = compute::default_exec_factory_registry(); + if (registry) { + InitializeScanner(registry); + InitializeDatasetWriter(registry); + registry = nullptr; + } +} + +} // namespace internal +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/plan.h b/cpp/src/arrow/dataset/plan.h new file mode 100644 index 00000000000..10260ccec81 --- /dev/null +++ b/cpp/src/arrow/dataset/plan.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#include "arrow/dataset/visibility.h" + +namespace arrow { +namespace dataset { +namespace internal { + +/// Register dataset-based exec nodes with the exec node registry +/// +/// This function must be called before using dataset ExecNode factories +ARROW_DS_EXPORT void Initialize(); + +} // namespace internal +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index f19b2372816..433e93172c9 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -31,6 +31,7 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/dataset/dataset.h" #include "arrow/dataset/dataset_internal.h" +#include "arrow/dataset/plan.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/table.h" #include "arrow/util/async_generator.h" @@ -339,6 +340,7 @@ class SyncScanner : public Scanner { Result ScanBatchesAsync() override; Result ScanBatchesUnorderedAsync() override; Result CountRows() override; + const std::shared_ptr& dataset() const override; protected: /// \brief GetFragments returns an iterator over all Fragments in this scan. @@ -416,6 +418,8 @@ Result SyncScanner::ScanInternal() { return GetScanTaskIterator(std::move(fragment_it), scan_options_); } +const std::shared_ptr& SyncScanner::dataset() const { return dataset_; } + class AsyncScanner : public Scanner, public std::enable_shared_from_this { public: AsyncScanner(std::shared_ptr dataset, @@ -431,6 +435,7 @@ class AsyncScanner : public Scanner, public std::enable_shared_from_this ScanBatchesUnorderedAsync() override; Result> ToTable() override; Result CountRows() override; + const std::shared_ptr& dataset() const override; private: Result ScanBatchesAsync(Executor* executor); @@ -812,6 +817,8 @@ Result AsyncScanner::CountRows() { return total.load(); } +const std::shared_ptr& AsyncScanner::dataset() const { return dataset_; } + } // namespace ScannerBuilder::ScannerBuilder(std::shared_ptr dataset) @@ -1310,17 +1317,12 @@ Result MakeOrderedSinkNode(compute::ExecPlan* plan, } // namespace namespace internal { - -void Initialize() { - static auto registry = compute::default_exec_factory_registry(); - if (registry) { - DCHECK_OK(registry->AddFactory("scan", MakeScanNode)); - DCHECK_OK(registry->AddFactory("ordered_sink", MakeOrderedSinkNode)); - DCHECK_OK(registry->AddFactory("augmented_project", MakeAugmentedProjectNode)); - registry = nullptr; - } +void InitializeScanner(arrow::compute::ExecFactoryRegistry* registry) { + DCHECK_OK(registry->AddFactory("scan", MakeScanNode)); + DCHECK_OK(registry->AddFactory("ordered_sink", MakeOrderedSinkNode)); + DCHECK_OK(registry->AddFactory("augmented_project", MakeAugmentedProjectNode)); } - } // namespace internal + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index e92ad7d4fc7..9264e9f548a 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -138,7 +138,7 @@ struct ARROW_DS_EXPORT ScanOptions { std::vector MaterializedFields() const; // 
Return a threaded or serial TaskGroup according to use_threads. - std::shared_ptr TaskGroup() const; + std::shared_ptr<::arrow::internal::TaskGroup> TaskGroup() const; }; /// \brief Read record batches from a range of a single data fragment. A @@ -150,8 +150,8 @@ class ARROW_DS_EXPORT ScanTask { /// resulting from the Scan. Execution semantics are encapsulated in the /// particular ScanTask implementation virtual Result Execute() = 0; - virtual Future SafeExecute(internal::Executor* executor); - virtual Future<> SafeVisit(internal::Executor* executor, + virtual Future SafeExecute(::arrow::internal::Executor* executor); + virtual Future<> SafeVisit(::arrow::internal::Executor* executor, std::function)> visitor); virtual ~ScanTask() = default; @@ -300,6 +300,8 @@ class ARROW_DS_EXPORT Scanner { /// \brief Get the options for this scan. const std::shared_ptr& options() const { return scan_options_; } + /// \brief Get the dataset that this scanner will scan + virtual const std::shared_ptr& dataset() const = 0; protected: explicit Scanner(std::shared_ptr scan_options) @@ -441,10 +443,7 @@ class ARROW_DS_EXPORT InMemoryScanTask : public ScanTask { }; namespace internal { - -/// This function must be called before using dataset ExecNode factories -ARROW_DS_EXPORT void Initialize(); - +ARROW_DS_EXPORT void InitializeScanner(arrow::compute::ExecFactoryRegistry* registry); } // namespace internal } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index ed66fb1cc26..0c6c3277290 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -27,6 +27,7 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/dataset/plan.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" #include "arrow/record_batch.h" diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 993fbd37144..722046e5eff 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -1107,24 +1107,24 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { SchemaFromColumnNames(source_schema_, {"year", "month"}))); expected_files_["/new_root/2018/1/dat_0"] = R"([ - {"region": "NY", "model": "3", "sales": 742.0, "country": "US"}, - {"region": "NY", "model": "S", "sales": 304.125, "country": "US"}, + {"region": "QC", "model": "X", "sales": 1.0, "country": "CA"}, {"region": "NY", "model": "Y", "sales": 27.5, "country": "US"}, - {"region": "QC", "model": "3", "sales": 512, "country": "CA"}, - {"region": "QC", "model": "S", "sales": 978, "country": "CA"}, + {"region": "QC", "model": "Y", "sales": 69, "country": "CA"}, {"region": "NY", "model": "X", "sales": 136.25, "country": "US"}, - {"region": "QC", "model": "X", "sales": 1.0, "country": "CA"}, - {"region": "QC", "model": "Y", "sales": 69, "country": "CA"} + {"region": "NY", "model": "S", "sales": 304.125, "country": "US"}, + {"region": "QC", "model": "3", "sales": 512, "country": "CA"}, + {"region": "NY", "model": "3", "sales": 742.0, "country": "US"}, + {"region": "QC", "model": "S", "sales": 978, "country": "CA"} ])"; expected_files_["/new_root/2019/1/dat_0"] = R"([ - {"region": "CA", "model": "3", "sales": 273.5, "country": "US"}, - {"region": "CA", "model": "S", "sales": 13, "country": "US"}, - {"region": "CA", "model": "X", "sales": 54, "country": "US"}, {"region": "QC", "model": "S", 
"sales": 10, "country": "CA"}, + {"region": "CA", "model": "S", "sales": 13, "country": "US"}, {"region": "CA", "model": "Y", "sales": 21, "country": "US"}, - {"region": "QC", "model": "3", "sales": 152.25, "country": "CA"}, + {"region": "QC", "model": "Y", "sales": 37, "country": "CA"}, {"region": "QC", "model": "X", "sales": 42, "country": "CA"}, - {"region": "QC", "model": "Y", "sales": 37, "country": "CA"} + {"region": "CA", "model": "X", "sales": 54, "country": "US"}, + {"region": "QC", "model": "3", "sales": 152.25, "country": "CA"}, + {"region": "CA", "model": "3", "sales": 273.5, "country": "US"} ])"; expected_physical_schema_ = SchemaFromColumnNames(source_schema_, {"region", "model", "sales", "country"}); @@ -1139,27 +1139,27 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { // XXX first thing a user will be annoyed by: we don't support left // padding the month field with 0. expected_files_["/new_root/US/NY/dat_0"] = R"([ - {"year": 2018, "month": 1, "model": "3", "sales": 742.0}, - {"year": 2018, "month": 1, "model": "S", "sales": 304.125}, {"year": 2018, "month": 1, "model": "Y", "sales": 27.5}, - {"year": 2018, "month": 1, "model": "X", "sales": 136.25} - ])"; + {"year": 2018, "month": 1, "model": "X", "sales": 136.25}, + {"year": 2018, "month": 1, "model": "S", "sales": 304.125}, + {"year": 2018, "month": 1, "model": "3", "sales": 742.0} + ])"; expected_files_["/new_root/CA/QC/dat_0"] = R"([ - {"year": 2018, "month": 1, "model": "3", "sales": 512}, - {"year": 2018, "month": 1, "model": "S", "sales": 978}, {"year": 2018, "month": 1, "model": "X", "sales": 1.0}, - {"year": 2018, "month": 1, "model": "Y", "sales": 69}, {"year": 2019, "month": 1, "model": "S", "sales": 10}, - {"year": 2019, "month": 1, "model": "3", "sales": 152.25}, + {"year": 2019, "month": 1, "model": "Y", "sales": 37}, {"year": 2019, "month": 1, "model": "X", "sales": 42}, - {"year": 2019, "month": 1, "model": "Y", "sales": 37} - ])"; + {"year": 2018, "month": 1, "model": "Y", "sales": 69}, + {"year": 2019, "month": 1, "model": "3", "sales": 152.25}, + {"year": 2018, "month": 1, "model": "3", "sales": 512}, + {"year": 2018, "month": 1, "model": "S", "sales": 978} + ])"; expected_files_["/new_root/US/CA/dat_0"] = R"([ - {"year": 2019, "month": 1, "model": "3", "sales": 273.5}, {"year": 2019, "month": 1, "model": "S", "sales": 13}, + {"year": 2019, "month": 1, "model": "Y", "sales": 21}, {"year": 2019, "month": 1, "model": "X", "sales": 54}, - {"year": 2019, "month": 1, "model": "Y", "sales": 21} - ])"; + {"year": 2019, "month": 1, "model": "3", "sales": 273.5} + ])"; expected_physical_schema_ = SchemaFromColumnNames(source_schema_, {"model", "sales", "year", "month"}); @@ -1173,29 +1173,29 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { // XXX first thing a user will be annoyed by: we don't support left // padding the month field with 0. 
expected_files_["/new_root/2018/1/US/NY/dat_0"] = R"([ - {"model": "3", "sales": 742.0}, - {"model": "S", "sales": 304.125}, {"model": "Y", "sales": 27.5}, - {"model": "X", "sales": 136.25} - ])"; + {"model": "X", "sales": 136.25}, + {"model": "S", "sales": 304.125}, + {"model": "3", "sales": 742.0} + ])"; expected_files_["/new_root/2018/1/CA/QC/dat_0"] = R"([ - {"model": "3", "sales": 512}, - {"model": "S", "sales": 978}, {"model": "X", "sales": 1.0}, - {"model": "Y", "sales": 69} - ])"; + {"model": "Y", "sales": 69}, + {"model": "3", "sales": 512}, + {"model": "S", "sales": 978} + ])"; expected_files_["/new_root/2019/1/US/CA/dat_0"] = R"([ - {"model": "3", "sales": 273.5}, {"model": "S", "sales": 13}, + {"model": "Y", "sales": 21}, {"model": "X", "sales": 54}, - {"model": "Y", "sales": 21} - ])"; + {"model": "3", "sales": 273.5} + ])"; expected_files_["/new_root/2019/1/CA/QC/dat_0"] = R"([ {"model": "S", "sales": 10}, - {"model": "3", "sales": 152.25}, + {"model": "Y", "sales": 37}, {"model": "X", "sales": 42}, - {"model": "Y", "sales": 37} - ])"; + {"model": "3", "sales": 152.25} + ])"; expected_physical_schema_ = SchemaFromColumnNames(source_schema_, {"model", "sales"}); AssertWrittenAsExpected(); @@ -1206,23 +1206,23 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { SchemaFromColumnNames(source_schema_, {}))); expected_files_["/new_root/dat_0"] = R"([ - {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "3", "sales": 742.0}, - {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "S", "sales": 304.125}, - {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "Y", "sales": 27.5}, - {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "3", "sales": 512}, - {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "S", "sales": 978}, - {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "X", "sales": 136.25}, {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "X", "sales": 1.0}, - {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "Y", "sales": 69}, - {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "3", "sales": 273.5}, - {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "S", "sales": 13}, - {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "X", "sales": 54}, {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "S", "sales": 10}, + {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "S", "sales": 13}, {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "Y", "sales": 21}, - {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "3", "sales": 152.25}, + {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "Y", "sales": 27.5}, + {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "Y", "sales": 37}, {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "X", "sales": 42}, - {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "Y", "sales": 37} - ])"; + {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "X", "sales": 54}, + {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "Y", "sales": 69}, + {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "X", "sales": 136.25}, + {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "3", "sales": 152.25}, + {"country": "US", "region": "CA", "year": 
2019, "month": 1, "model": "3", "sales": 273.5}, + {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "S", "sales": 304.125}, + {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "3", "sales": 512}, + {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "3", "sales": 742.0}, + {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "S", "sales": 978} + ])"; expected_physical_schema_ = source_schema_; AssertWrittenAsExpected(); @@ -1270,7 +1270,12 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { for (auto maybe_batch : MakeIteratorFromReader(std::make_shared(*actual_table))) { ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch); - ASSERT_OK_AND_ASSIGN(actual_struct, batch->ToStructArray()); + ASSERT_OK_AND_ASSIGN( + auto sort_indices, + compute::SortIndices(batch->GetColumnByName("sales"), + compute::SortOptions({compute::SortKey{"sales"}}))); + ASSERT_OK_AND_ASSIGN(Datum sorted_batch, compute::Take(batch, sort_indices)); + ASSERT_OK_AND_ASSIGN(actual_struct, sorted_batch.record_batch()->ToStructArray()); } auto expected_struct = ArrayFromJSON(struct_(expected_physical_schema_->fields()), diff --git a/cpp/src/arrow/util/async_util.cc b/cpp/src/arrow/util/async_util.cc index 76c971f576e..9407684bdda 100644 --- a/cpp/src/arrow/util/async_util.cc +++ b/cpp/src/arrow/util/async_util.cc @@ -41,7 +41,7 @@ void AsyncDestroyable::Destroy() { }); } -Status AsyncTaskGroup::AddTask(const Future<>& task) { +Status AsyncTaskGroup::AddTask(std::function>()> task) { auto guard = mutex_.Lock(); if (all_tasks_done_.is_finished()) { return Status::Invalid("Attempt to add a task after the task group has completed"); @@ -49,15 +49,25 @@ Status AsyncTaskGroup::AddTask(const Future<>& task) { if (!err_.ok()) { return err_; } + Result> maybe_task_fut = task(); + if (!maybe_task_fut.ok()) { + err_ = maybe_task_fut.status(); + return err_; + } + return AddTaskUnlocked(*maybe_task_fut, std::move(guard)); +} + +Status AsyncTaskGroup::AddTaskUnlocked(const Future<>& task_fut, + util::Mutex::Guard guard) { // If the task is already finished there is nothing to track so lets save // some work and return early - if (task.is_finished()) { - err_ &= task.status(); - return Status::OK(); + if (task_fut.is_finished()) { + err_ &= task_fut.status(); + return err_; } running_tasks_++; guard.Unlock(); - task.AddCallback([this](const Status& st) { + task_fut.AddCallback([this](const Status& st) { auto guard = mutex_.Lock(); err_ &= st; if (--running_tasks_ == 0 && finished_adding_) { @@ -68,7 +78,18 @@ Status AsyncTaskGroup::AddTask(const Future<>& task) { return Status::OK(); } -Future<> AsyncTaskGroup::WaitForTasksToFinish() { +Status AsyncTaskGroup::AddTask(const Future<>& task_fut) { + auto guard = mutex_.Lock(); + if (all_tasks_done_.is_finished()) { + return Status::Invalid("Attempt to add a task after the task group has completed"); + } + if (!err_.ok()) { + return err_; + } + return AddTaskUnlocked(task_fut, std::move(guard)); +} + +Future<> AsyncTaskGroup::End() { auto guard = mutex_.Lock(); finished_adding_ = true; if (running_tasks_ == 0) { @@ -78,5 +99,68 @@ Future<> AsyncTaskGroup::WaitForTasksToFinish() { return all_tasks_done_; } +Future<> AsyncTaskGroup::OnFinished() const { return all_tasks_done_; } + +SerializedAsyncTaskGroup::SerializedAsyncTaskGroup() : on_finished_(Future<>::Make()) {} + +Status SerializedAsyncTaskGroup::AddTask(std::function>()> task) { + util::Mutex::Guard guard = mutex_.Lock(); + 
ARROW_RETURN_NOT_OK(err_);
+  if (on_finished_.is_finished()) {
+    return Status::Invalid("Attempt to add a task after a task group has finished");
+  }
+  tasks_.push(std::move(task));
+  if (!processing_.is_valid()) {
+    ConsumeAsMuchAsPossibleUnlocked(std::move(guard));
+  }
+  return err_;
+}
+
+Future<> SerializedAsyncTaskGroup::End() {
+  util::Mutex::Guard guard = mutex_.Lock();
+  ended_ = true;
+  if (!processing_.is_valid()) {
+    guard.Unlock();
+    on_finished_.MarkFinished(err_);
+  }
+  return on_finished_;
+}
+
+void SerializedAsyncTaskGroup::ConsumeAsMuchAsPossibleUnlocked(
+    util::Mutex::Guard&& guard) {
+  while (err_.ok() && !tasks_.empty() && TryDrainUnlocked()) {
+  }
+  if (ended_ && tasks_.empty() && !processing_.is_valid()) {
+    guard.Unlock();
+    on_finished_.MarkFinished(err_);
+  }
+}
+
+bool SerializedAsyncTaskGroup::TryDrainUnlocked() {
+  if (processing_.is_valid()) {
+    return false;
+  }
+  std::function<Result<Future<>>()> next_task = std::move(tasks_.front());
+  tasks_.pop();
+  Result<Future<>> maybe_next_fut = next_task();
+  if (!maybe_next_fut.ok()) {
+    err_ &= maybe_next_fut.status();
+    return true;
+  }
+  Future<> next_fut = maybe_next_fut.MoveValueUnsafe();
+  if (next_fut.is_finished()) {
+    err_ &= next_fut.status();
+    return true;
+  }
+  processing_ = std::move(next_fut);
+  processing_.AddCallback([this](const Status& st) {
+    util::Mutex::Guard guard = mutex_.Lock();
+    processing_ = Future<>();
+    err_ &= st;
+    ConsumeAsMuchAsPossibleUnlocked(std::move(guard));
+  });
+  return false;
+}
+
 }  // namespace util
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/async_util.h b/cpp/src/arrow/util/async_util.h
index 9b0efd9e030..daa6bad8cee 100644
--- a/cpp/src/arrow/util/async_util.h
+++ b/cpp/src/arrow/util/async_util.h
@@ -17,6 +17,8 @@
 
 #pragma once
 
+#include <queue>
+
 #include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/util/future.h"
@@ -113,8 +115,10 @@ class ARROW_EXPORT AsyncTaskGroup {
   ///
   /// If WaitForTasksToFinish has been called and the returned future has been marked
   /// completed then adding a task will fail.
+  Status AddTask(std::function<Result<Future<>>()> task);
+  /// Add a task that has already been started
   Status AddTask(const Future<>& task);
-  /// A future that will be completed when all running tasks are finished.
+  /// Signal that top level tasks are done being added
   ///
   /// It is allowed for tasks to be added after this call provided the future has not yet
   /// completed. This should be safe as long as the tasks being added are added as part
   /// future will be marked complete.
   ///
   /// Any attempt to add a task after the returned future has completed will fail.
-  Future<> WaitForTasksToFinish();
+  ///
+  /// The returned future will finish when all running tasks have finished.
+  Future<> End();
+  /// A future that will be finished after End is called and all tasks have completed
+  ///
+  /// This is the same future that is returned by End() but calling this method does
+  /// not indicate that top level tasks are done being added. End() must still be called
+  /// at some point or the future returned will never finish.
+  ///
+  /// This is a utility method for workflows where the finish future needs to be
+  /// referenced before all top level tasks have been queued.
+  Future<> OnFinished() const;
 
  private:
+  Status AddTaskUnlocked(const Future<>& task, util::Mutex::Guard guard);
+
   bool finished_adding_ = false;
   int running_tasks_ = 0;
   Status err_;
@@ -132,5 +149,51 @@
   util::Mutex mutex_;
 };
 
+/// A task group which serializes asynchronous tasks in a push-based workflow
+///
+/// Tasks will be executed in the order they are added
+///
+/// This will buffer results in an unlimited fashion so it should be combined
+/// with some kind of backpressure
+class ARROW_EXPORT SerializedAsyncTaskGroup {
+ public:
+  SerializedAsyncTaskGroup();
+  /// Push a task into the group
+  ///
+  /// The task will not be started until all previously added tasks have been
+  /// completed.
+  ///
+  /// If a task fails then this group will go into an error state
+  /// and all subsequent pushes will fail with that error.  Tasks that have been queued
+  /// but not started will be silently dropped.
+  ///
+  /// \return An error if the group has already ended or failed, otherwise OK
+  Status AddTask(std::function<Result<Future<>>()> task);
+
+  /// Signal that all top level tasks have been added
+  ///
+  /// The returned future will finish when all added tasks have been consumed.
+  Future<> End();
+
+  /// A future that finishes when all queued items have been delivered.
+  ///
+  /// This will return the same future returned by End but will not signal
+  /// that all tasks have been finished.  End must be called at some point in order for
+  /// this future to finish.
+  Future<> OnFinished() const { return on_finished_; }
+
+ private:
+  void ConsumeAsMuchAsPossibleUnlocked(util::Mutex::Guard&& guard);
+  bool TryDrainUnlocked();
+
+  Future<> on_finished_;
+  std::queue<std::function<Result<Future<>>()>> tasks_;
+  util::Mutex mutex_;
+  bool ended_ = false;
+  Status err_;
+  Future<> processing_;
+};
+
 }  // namespace util
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/async_util_test.cc b/cpp/src/arrow/util/async_util_test.cc
index 9bae7977d45..eae4adfdfa1 100644
--- a/cpp/src/arrow/util/async_util_test.cc
+++ b/cpp/src/arrow/util/async_util_test.cc
@@ -96,13 +96,20 @@ TEST(AsyncDestroyable, MakeUnique) {
   });
 }
 
-TEST(AsyncTaskGroup, Basic) {
-  AsyncTaskGroup task_group;
+template <typename T>
+class TypedTestAsyncTaskGroup : public ::testing::Test {};
+
+using AsyncTaskGroupTypes = ::testing::Types<AsyncTaskGroup, SerializedAsyncTaskGroup>;
+
+TYPED_TEST_SUITE(TypedTestAsyncTaskGroup, AsyncTaskGroupTypes);
+
+TYPED_TEST(TypedTestAsyncTaskGroup, Basic) {
+  TypeParam task_group;
   Future<> fut1 = Future<>::Make();
   Future<> fut2 = Future<>::Make();
-  ASSERT_OK(task_group.AddTask(fut1));
-  ASSERT_OK(task_group.AddTask(fut2));
-  Future<> all_done = task_group.WaitForTasksToFinish();
+  ASSERT_OK(task_group.AddTask([fut1]() { return fut1; }));
+  ASSERT_OK(task_group.AddTask([fut2]() { return fut2; }));
+  Future<> all_done = task_group.End();
   AssertNotFinished(all_done);
   fut1.MarkFinished();
   AssertNotFinished(all_done);
@@ -110,25 +117,33 @@
   ASSERT_FINISHES_OK(all_done);
 }
 
-TEST(AsyncTaskGroup, NoTasks) {
-  AsyncTaskGroup task_group;
-  ASSERT_FINISHES_OK(task_group.WaitForTasksToFinish());
+TYPED_TEST(TypedTestAsyncTaskGroup, NoTasks) {
+  TypeParam task_group;
+  ASSERT_FINISHES_OK(task_group.End());
 }
 
-TEST(AsyncTaskGroup, AddAfterDone) {
-  AsyncTaskGroup task_group;
-  ASSERT_FINISHES_OK(task_group.WaitForTasksToFinish());
-  ASSERT_RAISES(Invalid, task_group.AddTask(Future<>::Make()));
+TYPED_TEST(TypedTestAsyncTaskGroup, OnFinishedDoesNotEnd)
{ + TypeParam task_group; + Future<> on_finished = task_group.OnFinished(); + AssertNotFinished(on_finished); + ASSERT_FINISHES_OK(task_group.End()); + ASSERT_FINISHES_OK(on_finished); } -TEST(AsyncTaskGroup, AddAfterWaitButBeforeFinish) { - AsyncTaskGroup task_group; +TYPED_TEST(TypedTestAsyncTaskGroup, AddAfterDone) { + TypeParam task_group; + ASSERT_FINISHES_OK(task_group.End()); + ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); })); +} + +TYPED_TEST(TypedTestAsyncTaskGroup, AddAfterWaitButBeforeFinish) { + TypeParam task_group; Future<> task_one = Future<>::Make(); - ASSERT_OK(task_group.AddTask(task_one)); - Future<> finish_fut = task_group.WaitForTasksToFinish(); + ASSERT_OK(task_group.AddTask([task_one] { return task_one; })); + Future<> finish_fut = task_group.End(); AssertNotFinished(finish_fut); Future<> task_two = Future<>::Make(); - ASSERT_OK(task_group.AddTask(task_two)); + ASSERT_OK(task_group.AddTask([task_two] { return task_two; })); AssertNotFinished(finish_fut); task_one.MarkFinished(); AssertNotFinished(finish_fut); @@ -137,45 +152,88 @@ TEST(AsyncTaskGroup, AddAfterWaitButBeforeFinish) { ASSERT_FINISHES_OK(finish_fut); } -TEST(AsyncTaskGroup, Error) { - AsyncTaskGroup task_group; +TYPED_TEST(TypedTestAsyncTaskGroup, Error) { + TypeParam task_group; Future<> failed_task = Future<>::MakeFinished(Status::Invalid("XYZ")); - ASSERT_OK(task_group.AddTask(failed_task)); - ASSERT_FINISHES_AND_RAISES(Invalid, task_group.WaitForTasksToFinish()); + ASSERT_RAISES(Invalid, task_group.AddTask([failed_task] { return failed_task; })); + ASSERT_FINISHES_AND_RAISES(Invalid, task_group.End()); +} + +TYPED_TEST(TypedTestAsyncTaskGroup, TaskFactoryFails) { + TypeParam task_group; + ASSERT_RAISES(Invalid, task_group.AddTask([] { return Status::Invalid("XYZ"); })); + ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); })); + ASSERT_FINISHES_AND_RAISES(Invalid, task_group.End()); } -TEST(AsyncTaskGroup, TaskFinishesAfterError) { +TYPED_TEST(TypedTestAsyncTaskGroup, AddAfterFailed) { + TypeParam task_group; + ASSERT_RAISES(Invalid, task_group.AddTask([] { + return Future<>::MakeFinished(Status::Invalid("XYZ")); + })); + ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); })); + ASSERT_FINISHES_AND_RAISES(Invalid, task_group.End()); +} + +TEST(StandardAsyncTaskGroup, TaskFinishesAfterError) { AsyncTaskGroup task_group; Future<> fut1 = Future<>::Make(); - ASSERT_OK(task_group.AddTask(fut1)); - ASSERT_OK(task_group.AddTask(Future<>::MakeFinished(Status::Invalid("XYZ")))); - Future<> finished_fut = task_group.WaitForTasksToFinish(); + ASSERT_OK(task_group.AddTask([fut1] { return fut1; })); + ASSERT_RAISES(Invalid, task_group.AddTask([] { + return Future<>::MakeFinished(Status::Invalid("XYZ")); + })); + Future<> finished_fut = task_group.End(); AssertNotFinished(finished_fut); fut1.MarkFinished(); ASSERT_FINISHES_AND_RAISES(Invalid, finished_fut); } -TEST(AsyncTaskGroup, AddAfterFailed) { - AsyncTaskGroup task_group; - ASSERT_OK(task_group.AddTask(Future<>::MakeFinished(Status::Invalid("XYZ")))); - ASSERT_RAISES(Invalid, task_group.AddTask(Future<>::Make())); - ASSERT_FINISHES_AND_RAISES(Invalid, task_group.WaitForTasksToFinish()); -} - -TEST(AsyncTaskGroup, FailAfterAdd) { +TEST(StandardAsyncTaskGroup, FailAfterAdd) { AsyncTaskGroup task_group; Future<> will_fail = Future<>::Make(); - ASSERT_OK(task_group.AddTask(will_fail)); + ASSERT_OK(task_group.AddTask([will_fail] { return will_fail; })); Future<> added_later_and_passes = 
Future<>::Make(); - ASSERT_OK(task_group.AddTask(added_later_and_passes)); + ASSERT_OK( + task_group.AddTask([added_later_and_passes] { return added_later_and_passes; })); will_fail.MarkFinished(Status::Invalid("XYZ")); - ASSERT_RAISES(Invalid, task_group.AddTask(Future<>::Make())); - Future<> finished_fut = task_group.WaitForTasksToFinish(); + ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); })); + Future<> finished_fut = task_group.End(); AssertNotFinished(finished_fut); added_later_and_passes.MarkFinished(); AssertFinished(finished_fut); ASSERT_FINISHES_AND_RAISES(Invalid, finished_fut); } +// The serialized task group can never really get into a "fail after add" scenario +// because there is no parallelism. So the behavior is a little unique in these scenarios + +TEST(SerializedAsyncTaskGroup, TaskFinishesAfterError) { + SerializedAsyncTaskGroup task_group; + Future<> fut1 = Future<>::Make(); + ASSERT_OK(task_group.AddTask([fut1] { return fut1; })); + ASSERT_OK( + task_group.AddTask([] { return Future<>::MakeFinished(Status::Invalid("XYZ")); })); + Future<> finished_fut = task_group.End(); + AssertNotFinished(finished_fut); + fut1.MarkFinished(); + ASSERT_FINISHES_AND_RAISES(Invalid, finished_fut); +} + +TEST(SerializedAsyncTaskGroup, FailAfterAdd) { + SerializedAsyncTaskGroup task_group; + Future<> will_fail = Future<>::Make(); + ASSERT_OK(task_group.AddTask([will_fail] { return will_fail; })); + Future<> added_later_and_passes = Future<>::Make(); + bool added_later_and_passes_created = false; + ASSERT_OK(task_group.AddTask([added_later_and_passes, &added_later_and_passes_created] { + added_later_and_passes_created = true; + return added_later_and_passes; + })); + will_fail.MarkFinished(Status::Invalid("XYZ")); + ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); })); + ASSERT_FINISHES_AND_RAISES(Invalid, task_group.End()); + ASSERT_FALSE(added_later_and_passes_created); +} + } // namespace util } // namespace arrow diff --git a/cpp/src/arrow/util/future.cc b/cpp/src/arrow/util/future.cc index fc8022a95e4..c398d992861 100644 --- a/cpp/src/arrow/util/future.cc +++ b/cpp/src/arrow/util/future.cc @@ -423,4 +423,15 @@ Future<> AllComplete(const std::vector>& futures) { return out; } +Future<> AllFinished(const std::vector>& futures) { + return All(futures).Then([](const std::vector>& results) { + for (const auto& res : results) { + if (!res.ok()) { + return res.status(); + } + } + return Status::OK(); + }); +} + } // namespace arrow diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index 6c194cab2ac..695ee9ff357 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -840,6 +840,17 @@ inline Future<>::Future(Status s) : Future(internal::Empty::ToResult(std::move(s ARROW_EXPORT Future<> AllComplete(const std::vector>& futures); +/// \brief Create a Future which completes when all of `futures` complete. +/// +/// The future will finish with an ok status if all `futures` finish with +/// an ok status. Otherwise, it will be marked failed with the status of +/// one of the failing futures. +/// +/// Unlike AllComplete this Future will not complete immediately when a +/// failure occurs. It will wait until all futures have finished. +ARROW_EXPORT +Future<> AllFinished(const std::vector>& futures); + /// \brief Wait for one of the futures to end, or for the given timeout to expire. /// /// The indices of all completed futures are returned. 
Note that some futures diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 6d84c8cce21..ffa1f28e634 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -3231,7 +3231,14 @@ def test_legacy_write_to_dataset_drops_null(tempdir): assert actual == expected -def _check_dataset_roundtrip(dataset, base_dir, expected_files, +def _sort_table(tab, sort_col): + import pyarrow.compute as pc + sorted_indices = pc.sort_indices( + tab, options=pc.SortOptions([(sort_col, 'ascending')])) + return pc.take(tab, sorted_indices) + + +def _check_dataset_roundtrip(dataset, base_dir, expected_files, sort_col, base_dir_path=None, partitioning=None): base_dir_path = base_dir_path or base_dir @@ -3245,7 +3252,9 @@ def _check_dataset_roundtrip(dataset, base_dir, expected_files, # check that reading back in as dataset gives the same result dataset2 = ds.dataset( base_dir_path, format="feather", partitioning=partitioning) - assert dataset2.to_table().equals(dataset.to_table()) + + assert _sort_table(dataset2.to_table(), sort_col).equals( + _sort_table(dataset.to_table(), sort_col)) @pytest.mark.parquet @@ -3259,12 +3268,12 @@ def test_write_dataset(tempdir): # full string path target = tempdir / 'single-file-target' expected_files = [target / "part-0.feather"] - _check_dataset_roundtrip(dataset, str(target), expected_files, target) + _check_dataset_roundtrip(dataset, str(target), expected_files, 'a', target) # pathlib path object target = tempdir / 'single-file-target2' expected_files = [target / "part-0.feather"] - _check_dataset_roundtrip(dataset, target, expected_files, target) + _check_dataset_roundtrip(dataset, target, expected_files, 'a', target) # TODO # # relative path @@ -3281,7 +3290,7 @@ def test_write_dataset(tempdir): target = tempdir / 'single-directory-target' expected_files = [target / "part-0.feather"] - _check_dataset_roundtrip(dataset, str(target), expected_files, target) + _check_dataset_roundtrip(dataset, str(target), expected_files, 'a', target) @pytest.mark.parquet @@ -3301,7 +3310,7 @@ def test_write_dataset_partitioned(tempdir): partitioning_schema = ds.partitioning( pa.schema([("part", pa.string())]), flavor="hive") _check_dataset_roundtrip( - dataset, str(target), expected_paths, target, + dataset, str(target), expected_paths, 'f1', target, partitioning=partitioning_schema) # directory partitioning @@ -3313,7 +3322,7 @@ def test_write_dataset_partitioned(tempdir): partitioning_schema = ds.partitioning( pa.schema([("part", pa.string())])) _check_dataset_roundtrip( - dataset, str(target), expected_paths, target, + dataset, str(target), expected_paths, 'f1', target, partitioning=partitioning_schema) @@ -3409,7 +3418,7 @@ def test_write_dataset_partitioned_dict(tempdir): # directories in _check_dataset_roundtrip (not currently required for # the formatting step) _check_dataset_roundtrip( - dataset, str(target), expected_paths, target, + dataset, str(target), expected_paths, 'f1', target, partitioning=partitioning) diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index c61f7a3d12f..39968b0ae32 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -96,6 +96,7 @@ std::shared_ptr ExecPlan_run( #if defined(ARROW_R_WITH_DATASET) +#include #include // [[dataset::export]]
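// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): implementing the new
// SinkNodeConsumer interface from options.h and wiring it into a plan through
// the "consuming_sink" factory registered in sink_node.cc.  This mirrors the
// SourceConsumingSink test above; RowCountingConsumer, schema, and gen are
// assumed names introduced only for this example.
#include <atomic>
#include <memory>

#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/optional.h"

namespace cp = arrow::compute;

class RowCountingConsumer : public cp::SinkNodeConsumer {
 public:
  arrow::Status Consume(cp::ExecBatch batch) override {
    rows_seen_ += batch.length;  // called once per batch pushed into the sink
    return arrow::Status::OK();
  }
  arrow::Future<> Finish() override {
    // Nothing asynchronous is outstanding, so the plan may finish immediately.
    return arrow::Future<>::MakeFinished();
  }
  std::atomic<int64_t> rows_seen_{0};
};

arrow::Status RunConsumingPlan(
    std::shared_ptr<arrow::Schema> schema,
    arrow::AsyncGenerator<arrow::util::optional<cp::ExecBatch>> gen) {
  ARROW_ASSIGN_OR_RAISE(auto plan, cp::ExecPlan::Make());
  auto consumer = std::make_shared<RowCountingConsumer>();
  ARROW_ASSIGN_OR_RAISE(cp::ExecNode* source,
                        cp::MakeExecNode("source", plan.get(), {},
                                         cp::SourceNodeOptions{schema, gen}));
  // The plan's finished() future will not complete until consumer->Finish() does.
  ARROW_RETURN_NOT_OK(cp::MakeExecNode("consuming_sink", plan.get(), {source},
                                       cp::ConsumingSinkNodeOptions{consumer}));
  ARROW_RETURN_NOT_OK(plan->StartProducing());
  return plan->finished().status();
}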
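// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): inside a SinkNodeConsumer the
// ExecBatch can be converted back to a RecordBatch with the plan's output
// schema, the same call DatasetWritingSinkNodeConsumer::Consume makes above.
#include <memory>

#include "arrow/compute/exec.h"
#include "arrow/record_batch.h"

arrow::Result<std::shared_ptr<arrow::RecordBatch>> ExecBatchToRecordBatch(
    const arrow::compute::ExecBatch& batch, std::shared_ptr<arrow::Schema> schema) {
  // Reassembles the ExecBatch columns under the given schema.
  return batch.ToRecordBatch(std::move(schema));
}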
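// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the exec_plan.cc hunk above
// switches ExecPlan::finished() from AllComplete to the new AllFinished.
// AllComplete fails as soon as any future fails, while AllFinished waits for
// every future before reporting the first error, so a failing node cannot race
// with nodes that are still shutting down.  WaitForAllNodes is an assumed name.
#include <vector>

#include "arrow/util/future.h"

arrow::Future<> WaitForAllNodes(std::vector<arrow::Future<>> node_futures) {
  // Completes only after every element of node_futures is finished; the result
  // is OK only if all of them finished OK.
  return arrow::AllFinished(node_futures);
}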
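// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the new AtomicCounter::Completed()
// accessor from compute/exec/util.h lets ConsumingSinkNode drop batches that
// arrive after the node has already finished or been cancelled.
#include "arrow/compute/exec/util.h"

bool ShouldDropLateBatch(arrow::compute::AtomicCounter* input_counter) {
  // True once Increment(), SetTotal(), or Cancel() has completed the counter,
  // e.g. after an earlier Consume() error cancelled the sink.
  return input_counter->Completed();
}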
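// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the reworked AsyncTaskGroup.
// Tasks are now submitted through a factory returning Result<Future<>> (so a
// group that has failed or ended can refuse to start new work), and
// WaitForTasksToFinish() has been renamed End().  StartReadAsync and
// ReadAllFiles are hypothetical names used only for this example.
#include <memory>

#include "arrow/util/async_util.h"

arrow::Future<> StartReadAsync(int file_index);  // assumed to exist elsewhere

arrow::Future<> ReadAllFiles(int num_files) {
  auto task_group = std::make_shared<arrow::util::AsyncTaskGroup>();
  for (int i = 0; i < num_files; i++) {
    // The factory is only invoked if the group has not already failed or ended.
    arrow::Status st = task_group->AddTask(
        [i]() -> arrow::Result<arrow::Future<>> { return StartReadAsync(i); });
    if (!st.ok()) return arrow::Future<>::MakeFinished(st);
  }
  // No more top-level tasks will be added; the returned future finishes once
  // every submitted task has finished, carrying the first error if any.
  return task_group->End().Then([task_group]() {});
}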
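// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): SerializedAsyncTaskGroup runs
// its task factories strictly one at a time, in insertion order, which is how
// DatasetWritingSinkNodeConsumer keeps batch writes ordered above.
// WriteBatchAsync and WriteBatchesInOrder are hypothetical names.
#include <memory>

#include "arrow/util/async_util.h"

arrow::Future<> WriteBatchAsync(int batch_index);  // assumed to exist elsewhere

arrow::Future<> WriteBatchesInOrder(int num_batches) {
  auto task_group = std::make_shared<arrow::util::SerializedAsyncTaskGroup>();
  for (int i = 0; i < num_batches; i++) {
    // Queued, not started: the factory for batch i+1 runs only after the future
    // returned for batch i has completed.
    arrow::Status st = task_group->AddTask(
        [i]() -> arrow::Result<arrow::Future<>> { return WriteBatchAsync(i); });
    if (!st.ok()) return arrow::Future<>::MakeFinished(st);
  }
  return task_group->End().Then([task_group]() {});
}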
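// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): OnFinished() exposes the same
// future as End() without declaring that task submission is over, matching the
// OnFinishedDoesNotEnd test above.  StreamingWriter is an assumed name.
#include "arrow/util/async_util.h"

struct StreamingWriter {
  // Callers may grab the completion future before any work has been queued...
  arrow::Future<> Finished() { return task_group.OnFinished(); }

  // ...but it only finishes after Close() calls End() and the queue drains.
  arrow::Future<> Close() { return task_group.End(); }

  arrow::util::SerializedAsyncTaskGroup task_group;
};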
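// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): driving the new "write" node
// directly, mirroring FileSystemDataset::Write above.  The dataset,
// scan_options, and write_options arguments are assumed to be prepared by the
// caller, and the projection step that Write() inserts is omitted for brevity.
#include <memory>

#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/plan.h"
#include "arrow/dataset/scanner.h"

namespace cp = arrow::compute;
namespace ds = arrow::dataset;

arrow::Status WriteWithExecPlan(std::shared_ptr<ds::Dataset> dataset,
                                std::shared_ptr<ds::ScanOptions> scan_options,
                                ds::FileSystemDatasetWriteOptions write_options) {
  // Registers the "scan", "write", and related factories added by this patch.
  ds::internal::Initialize();
  ARROW_ASSIGN_OR_RAISE(auto plan, cp::ExecPlan::Make());
  ARROW_RETURN_NOT_OK(
      cp::Declaration::Sequence(
          {
              {"scan", ds::ScanNodeOptions{dataset, scan_options}},
              {"filter", cp::FilterNodeOptions{scan_options->filter}},
              {"write",
               ds::WriteNodeOptions{write_options, scan_options->projected_schema}},
          })
          .AddToPlan(plan.get()));
  ARROW_RETURN_NOT_OK(plan->StartProducing());
  // Like the "consuming_sink" it is built on, the write node keeps the plan
  // running until all queued file writes have completed.
  return plan->finished().status();
}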
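// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the test changes above stop
// assuming a stable row order from the new write path and instead sort each
// batch by "sales" before comparing, using the same SortIndices + Take pattern
// as AssertWrittenAsExpected.  SortBySales is an assumed helper name.
#include <memory>

#include "arrow/compute/api_vector.h"
#include "arrow/record_batch.h"

arrow::Result<std::shared_ptr<arrow::RecordBatch>> SortBySales(
    const std::shared_ptr<arrow::RecordBatch>& batch) {
  namespace cp = arrow::compute;
  ARROW_ASSIGN_OR_RAISE(
      auto sort_indices,
      cp::SortIndices(batch->GetColumnByName("sales"),
                      cp::SortOptions({cp::SortKey{"sales"}})));
  // Reorder the whole batch by the sorted indices of the "sales" column.
  ARROW_ASSIGN_OR_RAISE(arrow::Datum sorted, cp::Take(batch, sort_indices));
  return sorted.record_batch();
}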