diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc index 0e86df26ad8..c4352360e6b 100644 --- a/cpp/src/arrow/csv/reader.cc +++ b/cpp/src/arrow/csv/reader.cc @@ -1001,9 +1001,8 @@ Result> MakeTableReader( Future> MakeStreamingReader( io::IOContext io_context, std::shared_ptr input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options) { - auto cpu_executor = internal::GetCpuThreadPool(); + internal::Executor* cpu_executor, const ReadOptions& read_options, + const ParseOptions& parse_options, const ConvertOptions& convert_options) { std::shared_ptr reader; reader = std::make_shared( io_context, cpu_executor, input, read_options, parse_options, convert_options); @@ -1036,8 +1035,9 @@ Result> StreamingReader::Make( const ReadOptions& read_options, const ParseOptions& parse_options, const ConvertOptions& convert_options) { auto io_context = io::IOContext(pool); - auto reader_fut = MakeStreamingReader(io_context, std::move(input), read_options, - parse_options, convert_options); + auto cpu_executor = internal::GetCpuThreadPool(); + auto reader_fut = MakeStreamingReader(io_context, std::move(input), cpu_executor, + read_options, parse_options, convert_options); auto reader_result = reader_fut.result(); ARROW_ASSIGN_OR_RAISE(auto reader, reader_result); return reader; @@ -1047,8 +1047,9 @@ Result> StreamingReader::Make( io::IOContext io_context, std::shared_ptr input, const ReadOptions& read_options, const ParseOptions& parse_options, const ConvertOptions& convert_options) { - auto reader_fut = MakeStreamingReader(io_context, std::move(input), read_options, - parse_options, convert_options); + auto cpu_executor = internal::GetCpuThreadPool(); + auto reader_fut = MakeStreamingReader(io_context, std::move(input), cpu_executor, + read_options, parse_options, convert_options); auto reader_result = reader_fut.result(); ARROW_ASSIGN_OR_RAISE(auto reader, reader_result); return reader; @@ -1056,10 +1057,10 @@ Result> StreamingReader::Make( Future> StreamingReader::MakeAsync( io::IOContext io_context, std::shared_ptr input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options) { - return MakeStreamingReader(io_context, std::move(input), read_options, parse_options, - convert_options); + internal::Executor* cpu_executor, const ReadOptions& read_options, + const ParseOptions& parse_options, const ConvertOptions& convert_options) { + return MakeStreamingReader(io_context, std::move(input), cpu_executor, read_options, + parse_options, convert_options); } } // namespace csv diff --git a/cpp/src/arrow/csv/reader.h b/cpp/src/arrow/csv/reader.h index 79015e941ee..72f1375cc3c 100644 --- a/cpp/src/arrow/csv/reader.h +++ b/cpp/src/arrow/csv/reader.h @@ -26,6 +26,7 @@ #include "arrow/type.h" #include "arrow/type_fwd.h" #include "arrow/util/future.h" +#include "arrow/util/thread_pool.h" #include "arrow/util/visibility.h" namespace arrow { @@ -72,7 +73,8 @@ class ARROW_EXPORT StreamingReader : public RecordBatchReader { /// parsing (see ARROW-11889) static Future> MakeAsync( io::IOContext io_context, std::shared_ptr input, - const ReadOptions&, const ParseOptions&, const ConvertOptions&); + internal::Executor* cpu_executor, const ReadOptions&, const ParseOptions&, + const ConvertOptions&); static Result> Make( io::IOContext io_context, std::shared_ptr input, diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 8437c75ae1c..ad19bd2041e 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -418,14 +418,46 @@ Status WriteNextBatch(WriteState& state, const std::shared_ptr& scan_t return Status::OK(); } +Future<> WriteInternal(const ScanOptions& scan_options, WriteState& state, + ScanTaskVector scan_tasks, internal::Executor* cpu_executor) { + // Store a mapping from partitions (represened by their formatted partition expressions) + // to a WriteQueue which flushes batches into that partition's output file. In principle + // any thread could produce a batch for any partition, so each task alternates between + // pushing batches and flushing them to disk. + std::vector> scan_futs; + auto task_group = scan_options.TaskGroup(); + + for (const auto& scan_task : scan_tasks) { + if (scan_task->supports_async()) { + ARROW_ASSIGN_OR_RAISE(auto batches_gen, scan_task->ExecuteAsync(cpu_executor)); + std::function batch)> batch_visitor = + [&, scan_task](std::shared_ptr batch) { + return WriteNextBatch(state, scan_task, std::move(batch)); + }; + scan_futs.push_back(VisitAsyncGenerator(batches_gen, batch_visitor)); + } else { + task_group->Append([&, scan_task] { + ARROW_ASSIGN_OR_RAISE(auto batches, scan_task->Execute()); + + for (auto maybe_batch : batches) { + ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch); + RETURN_NOT_OK(WriteNextBatch(state, scan_task, std::move(batch))); + } + + return Status::OK(); + }); + } + } + scan_futs.push_back(task_group->FinishAsync()); + return AllComplete(scan_futs); +} + } // namespace Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options, std::shared_ptr scanner) { RETURN_NOT_OK(ValidateBasenameTemplate(write_options.basename_template)); - auto task_group = scanner->options()->TaskGroup(); - // Things we'll un-lazy for the sake of simplicity, with the tradeoff they represent: // // - Fragment iteration. Keeping this lazy would allow us to start partitioning/writing @@ -440,7 +472,6 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio ARROW_ASSIGN_OR_RAISE(auto fragment_it, scanner->GetFragments()); ARROW_ASSIGN_OR_RAISE(FragmentVector fragments, fragment_it.ToVector()); ScanTaskVector scan_tasks; - std::vector> scan_futs; for (const auto& fragment : fragments) { auto options = std::make_shared(*scanner->options()); @@ -454,38 +485,16 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio } } - // Store a mapping from partitions (represened by their formatted partition expressions) - // to a WriteQueue which flushes batches into that partition's output file. In principle - // any thread could produce a batch for any partition, so each task alternates between - // pushing batches and flushing them to disk. WriteState state(write_options); + auto res = internal::RunSynchronously( + [&](internal::Executor* cpu_executor) -> Future<> { + return WriteInternal(*scanner->options(), state, std::move(scan_tasks), + cpu_executor); + }, + scanner->options()->use_threads); + RETURN_NOT_OK(res); - for (const auto& scan_task : scan_tasks) { - if (scan_task->supports_async()) { - ARROW_ASSIGN_OR_RAISE(auto batches_gen, scan_task->ExecuteAsync()); - std::function batch)> batch_visitor = - [&, scan_task](std::shared_ptr batch) { - return WriteNextBatch(state, scan_task, std::move(batch)); - }; - scan_futs.push_back(VisitAsyncGenerator(batches_gen, batch_visitor)); - } else { - task_group->Append([&, scan_task] { - ARROW_ASSIGN_OR_RAISE(auto batches, scan_task->Execute()); - - for (auto maybe_batch : batches) { - ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch); - RETURN_NOT_OK(WriteNextBatch(state, scan_task, std::move(batch))); - } - - return Status::OK(); - }); - } - } - RETURN_NOT_OK(task_group->Finish()); - auto scan_futs_all_done = AllComplete(scan_futs); - RETURN_NOT_OK(scan_futs_all_done.status()); - - task_group = scanner->options()->TaskGroup(); + auto task_group = scanner->options()->TaskGroup(); for (const auto& part_queue : state.queues) { task_group->Append([&] { return part_queue.second->writer()->Finish(); }); } diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index b55c23dfdef..677d1be05b7 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -43,6 +43,8 @@ namespace dataset { using internal::checked_cast; using internal::checked_pointer_cast; +using internal::Executor; +using internal::SerialExecutor; using RecordBatchGenerator = AsyncGenerator>; Result> GetColumnNames( @@ -107,13 +109,14 @@ static inline Result GetReadOptions( auto read_options = csv_scan_options->read_options; // Multithreaded conversion of individual files would lead to excessive thread // contention when ScanTasks are also executed in multiple threads, so we disable it - // here. + // here. Also, this is a no-op since the streaming CSV reader is currently serial read_options.use_threads = false; return read_options; } static inline Future> OpenReaderAsync( const FileSource& source, const CsvFileFormat& format, + internal::Executor* cpu_executor, const std::shared_ptr& scan_options = nullptr, MemoryPool* pool = default_memory_pool()) { ARROW_ASSIGN_OR_RAISE(auto reader_options, GetReadOptions(format, scan_options)); @@ -136,7 +139,8 @@ static inline Future> OpenReaderAsync( } return csv::StreamingReader::MakeAsync(io::default_io_context(), std::move(input), - reader_options, parse_options, convert_options) + cpu_executor, reader_options, parse_options, + convert_options) .Then( [](const std::shared_ptr& maybe_reader) -> Result> { return maybe_reader; }, @@ -151,8 +155,12 @@ static inline Result> OpenReader( const FileSource& source, const CsvFileFormat& format, const std::shared_ptr& scan_options = nullptr, MemoryPool* pool = default_memory_pool()) { - auto open_reader_fut = OpenReaderAsync(source, format, scan_options, pool); - return open_reader_fut.result(); + bool use_threads = (scan_options != nullptr && scan_options->use_threads); + return internal::RunSynchronously>( + [&](Executor* executor) { + return OpenReaderAsync(source, format, executor, scan_options, pool); + }, + use_threads); } /// \brief A ScanTask backed by an Csv file. @@ -166,14 +174,15 @@ class CsvScanTask : public ScanTask { source_(fragment->source()) {} Result Execute() override { - ARROW_ASSIGN_OR_RAISE(auto gen, ExecuteAsync()); + ARROW_ASSIGN_OR_RAISE(auto gen, ExecuteAsync(internal::GetCpuThreadPool())); return MakeGeneratorIterator(std::move(gen)); } bool supports_async() const override { return true; } - Result ExecuteAsync() override { - auto reader_fut = OpenReaderAsync(source_, *format_, options(), options()->pool); + Result ExecuteAsync(internal::Executor* cpu_executor) override { + auto reader_fut = + OpenReaderAsync(source_, *format_, cpu_executor, options(), options()->pool); auto generator_fut = reader_fut.Then( [](const std::shared_ptr& reader) -> RecordBatchGenerator { return [reader]() { return reader->ReadNextAsync(); }; diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 2258a10d141..a8ac24b7799 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -26,6 +26,7 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/table.h" +#include "arrow/util/async_generator.h" #include "arrow/util/iterator.h" #include "arrow/util/logging.h" #include "arrow/util/task_group.h" @@ -47,6 +48,8 @@ std::vector ScanOptions::MaterializedFields() const { return fields; } +using arrow::internal::Executor; +using arrow::internal::SerialExecutor; using arrow::internal::TaskGroup; std::shared_ptr ScanOptions::TaskGroup() const { @@ -61,7 +64,7 @@ Result InMemoryScanTask::Execute() { return MakeVectorIterator(record_batches_); } -Result ScanTask::ExecuteAsync() { +Result ScanTask::ExecuteAsync(internal::Executor*) { return Status::NotImplemented("Async is not implemented for this scan task yet"); } @@ -200,6 +203,13 @@ struct TableAssemblyState { }; Result> Scanner::ToTable() { + return internal::RunSynchronously>( + [this](Executor* executor) { return ToTableInternal(executor); }, + scan_options_->use_threads); +} + +Future> Scanner::ToTableInternal( + internal::Executor* cpu_executor) { ARROW_ASSIGN_OR_RAISE(auto scan_task_it, Scan()); auto task_group = scan_options_->TaskGroup(); @@ -215,7 +225,7 @@ Result> Scanner::ToTable() { auto id = scan_task_id++; if (scan_task->supports_async()) { - ARROW_ASSIGN_OR_RAISE(auto scan_gen, scan_task->ExecuteAsync()); + ARROW_ASSIGN_OR_RAISE(auto scan_gen, scan_task->ExecuteAsync(cpu_executor)); auto scan_fut = CollectAsyncGenerator(std::move(scan_gen)) .Then([state, id](const RecordBatchVector& rbs) { state->Emplace(rbs, id); @@ -230,14 +240,16 @@ Result> Scanner::ToTable() { }); } } - // Wait for all async tasks to complete, or the first error - RETURN_NOT_OK(AllComplete(scan_futures).status()); - - // Wait for all sync tasks to complete, or the first error. - RETURN_NOT_OK(task_group->Finish()); - - return Table::FromRecordBatches(scan_options_->projected_schema, - FlattenRecordBatchVector(std::move(state->batches))); + auto scan_options = scan_options_; + scan_futures.push_back(task_group->FinishAsync()); + // Wait for all tasks to complete, or the first error + return AllComplete(scan_futures) + .Then( + [scan_options, state](const detail::Empty&) -> Result> { + return Table::FromRecordBatches( + scan_options->projected_schema, + FlattenRecordBatchVector(std::move(state->batches))); + }); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index c3cce00d8c5..9bd4b10847b 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -19,6 +19,7 @@ #pragma once +#include #include #include #include @@ -31,11 +32,12 @@ #include "arrow/dataset/visibility.h" #include "arrow/memory_pool.h" #include "arrow/type_fwd.h" -#include "arrow/util/async_generator.h" #include "arrow/util/type_fwd.h" namespace arrow { -using RecordBatchGenerator = AsyncGenerator>; + +using RecordBatchGenerator = std::function>()>; + namespace dataset { constexpr int64_t kDefaultBatchSize = 1 << 20; @@ -103,7 +105,7 @@ class ARROW_DS_EXPORT ScanTask { /// resulting from the Scan. Execution semantics are encapsulated in the /// particular ScanTask implementation virtual Result Execute() = 0; - virtual Result ExecuteAsync(); + virtual Result ExecuteAsync(internal::Executor* cpu_executor); virtual bool supports_async() const; virtual ~ScanTask() = default; @@ -175,6 +177,8 @@ class ARROW_DS_EXPORT Scanner { const std::shared_ptr& options() const { return scan_options_; } protected: + Future> ToTableInternal(internal::Executor* cpu_executor); + std::shared_ptr dataset_; // TODO(ARROW-8065) remove fragment_ after a Dataset is constuctible from fragments std::shared_ptr fragment_; diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index 3101be477fd..d334c094d31 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -28,11 +28,13 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/scanner.h" +#include "arrow/util/async_generator.h" #include "arrow/util/logging.h" namespace arrow { using internal::checked_cast; +using internal::Executor; namespace dataset { @@ -171,22 +173,15 @@ class FilterAndProjectScanTask : public ScanTask { options_->pool); } - Result Execute() override { - if (task_->supports_async()) { - ARROW_ASSIGN_OR_RAISE(auto gen, ExecuteAsync()); - return MakeGeneratorIterator(std::move(gen)); - } else { - return ExecuteSync(); - } - } + Result Execute() override { return ExecuteSync(); } - Result ExecuteAsync() override { + Result ExecuteAsync(Executor* cpu_executor) override { if (!task_->supports_async()) { return Status::Invalid( "ExecuteAsync should not have been called on FilterAndProjectScanTask if the " "source task did not support async"); } - ARROW_ASSIGN_OR_RAISE(auto gen, task_->ExecuteAsync()); + ARROW_ASSIGN_OR_RAISE(auto gen, task_->ExecuteAsync(cpu_executor)); ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, SimplifyWithGuarantee(options()->filter, partition_)); diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 86bb14b038d..a6e761cf8c5 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -813,15 +813,15 @@ class NestedParallelismMixin : public ::testing::Test { // supports_async() to false (will deadlock) ADD_FAILURE() << "NestedParallelismScanTask::Execute should never be called. You " "should be deadlocked right now"; - ARROW_ASSIGN_OR_RAISE(auto batch_gen, ExecuteAsync()); + ARROW_ASSIGN_OR_RAISE(auto batch_gen, ExecuteAsync(internal::GetCpuThreadPool())); return MakeGeneratorIterator(std::move(batch_gen)); } - Result ExecuteAsync() override { + Result ExecuteAsync(internal::Executor* cpu_executor) override { ARROW_ASSIGN_OR_RAISE(auto batches_it, target_->Execute()); ARROW_ASSIGN_OR_RAISE(auto batches, batches_it.ToVector()); - auto generator_fut = DeferNotOk(internal::GetCpuThreadPool()->Submit( - [batches] { return MakeVectorGenerator(batches); })); + auto generator_fut = DeferNotOk( + cpu_executor->Submit([batches] { return MakeVectorGenerator(batches); })); return MakeFromFuture(generator_fut); } diff --git a/cpp/src/arrow/util/thread_pool.cc b/cpp/src/arrow/util/thread_pool.cc index f2a8368d273..873b9335e74 100644 --- a/cpp/src/arrow/util/thread_pool.cc +++ b/cpp/src/arrow/util/thread_pool.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -44,6 +45,64 @@ struct Task { } // namespace +struct SerialExecutor::State { + std::queue task_queue; + std::mutex mutex; + std::condition_variable wait_for_tasks; + bool finished; +}; + +SerialExecutor::SerialExecutor() : state_(new State()) {} +SerialExecutor::~SerialExecutor() {} + +Status SerialExecutor::SpawnReal(TaskHints hints, FnOnce task, + StopToken stop_token, StopCallback&& stop_callback) { + // The serial task queue is truly serial (no mutex needed) but SpawnReal may be called + // from external threads (e.g. when transferring back from blocking I/O threads) so a + // mutex is needed + { + std::lock_guard lg(state_->mutex); + state_->task_queue.push( + Task{std::move(task), std::move(stop_token), std::move(stop_callback)}); + } + state_->wait_for_tasks.notify_one(); + return Status::OK(); +} + +void SerialExecutor::MarkFinished() { + std::lock_guard lk(state_->mutex); + state_->finished = true; + // Keep the lock when notifying to avoid situations where the SerialExecutor + // would start being destroyed while the notify_one() call is still ongoing. + state_->wait_for_tasks.notify_one(); +} + +void SerialExecutor::RunLoop() { + std::unique_lock lk(state_->mutex); + + while (!state_->finished) { + while (!state_->task_queue.empty()) { + Task task = std::move(state_->task_queue.front()); + state_->task_queue.pop(); + lk.unlock(); + if (!task.stop_token.IsStopRequested()) { + std::move(task.callable)(); + } else { + if (task.stop_callback) { + std::move(task.stop_callback)(task.stop_token.Poll()); + } + // Can't break here because there may be cleanup tasks down the chain we still + // need to run. + } + lk.lock(); + } + // In this case we must be waiting on work from external (e.g. I/O) executors. Wait + // for tasks to arrive (typically via transferred futures). + state_->wait_for_tasks.wait( + lk, [&] { return state_->finished || !state_->task_queue.empty(); }); + } +} + struct ThreadPool::State { State() = default; @@ -350,6 +409,11 @@ ThreadPool* GetCpuThreadPool() { return singleton.get(); } +Status RunSynchronouslyVoid(FnOnce(Executor*)> get_future, + bool use_threads) { + return RunSynchronously(std::move(get_future), use_threads).status(); +} + } // namespace internal int GetCpuThreadPoolCapacity() { return internal::GetCpuThreadPool()->GetCapacity(); } diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h index 0abe381f100..c4d4d1869c6 100644 --- a/cpp/src/arrow/util/thread_pool.h +++ b/cpp/src/arrow/util/thread_pool.h @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -189,8 +190,63 @@ class ARROW_EXPORT Executor { StopCallback&&) = 0; }; -// An Executor implementation spawning tasks in FIFO manner on a fixed-size -// pool of worker threads. +/// \brief An executor implementation that runs all tasks on a single thread using an +/// event loop. +/// +/// Note: Any sort of nested parallelism will deadlock this executor. Blocking waits are +/// fine but if one task needs to wait for another task it must be expressed as an +/// asynchronous continuation. +class ARROW_EXPORT SerialExecutor : public Executor { + public: + template + using TopLevelTask = internal::FnOnce(Executor*)>; + + ~SerialExecutor(); + + int GetCapacity() override { return 1; }; + Status SpawnReal(TaskHints hints, FnOnce task, StopToken, + StopCallback&&) override; + + /// \brief Runs the TopLevelTask and any scheduled tasks + /// + /// The TopLevelTask (or one of the tasks it schedules) must either return an invalid + /// status or call the finish signal. Failure to do this will result in a deadlock. For + /// this reason it is preferable (if possible) to use the helper methods (below) + /// RunSynchronously/RunSerially which delegates the responsiblity onto a Future + /// producer's existing responsibility to always mark a future finished (which can + /// someday be aided by ARROW-12207). + template + static Result RunInSerialExecutor(TopLevelTask initial_task) { + return SerialExecutor().Run(std::move(initial_task)); + } + + private: + SerialExecutor(); + + // State uses mutex + struct State; + std::unique_ptr state_; + + template + Result Run(TopLevelTask initial_task) { + auto final_fut = std::move(initial_task)(this); + if (final_fut.is_finished()) { + return final_fut.result(); + } + final_fut.AddCallback([this](const Result&) { MarkFinished(); }); + RunLoop(); + return final_fut.result(); + } + void RunLoop(); + void MarkFinished(); +}; + +/// An Executor implementation spawning tasks in FIFO manner on a fixed-size +/// pool of worker threads. +/// +/// Note: Any sort of nested parallelism will deadlock this executor. Blocking waits are +/// fine but if one task needs to wait for another task it must be expressed as an +/// asynchronous continuation. class ARROW_EXPORT ThreadPool : public Executor { public: // Construct a thread pool with the given number of worker threads @@ -262,5 +318,24 @@ class ARROW_EXPORT ThreadPool : public Executor { // Return the process-global thread pool for CPU-bound tasks. ARROW_EXPORT ThreadPool* GetCpuThreadPool(); +/// \brief Potentially run an async operation serially (if use_threads is false) +/// \see RunSerially +/// +/// If `use_threads` is true, the global CPU executor is used. +/// If `use_threads` is false, a temporary SerialExecutor is used. +/// `get_future` is called (from this thread) with the chosen executor and must +/// return a future that will eventually finish. This function returns once the +/// future has finished. +template +Result RunSynchronously(FnOnce(Executor*)> get_future, bool use_threads) { + if (use_threads) { + return std::move(get_future)(GetCpuThreadPool()).result(); + } else { + return SerialExecutor::RunInSerialExecutor(std::move(get_future)); + } +} + +ARROW_EXPORT Status RunSynchronouslyVoid( + FnOnce(Executor*)> get_future, bool use_threads); } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/thread_pool_test.cc b/cpp/src/arrow/util/thread_pool_test.cc index 6f686ee650b..2390f8c1a41 100644 --- a/cpp/src/arrow/util/thread_pool_test.cc +++ b/cpp/src/arrow/util/thread_pool_test.cc @@ -123,6 +123,128 @@ class AddTester { std::vector outs_; }; +class TestRunSynchronously : public testing::TestWithParam { + public: + bool UseThreads() { return GetParam(); } + + template + Result Run(FnOnce(Executor*)> top_level_task) { + return RunSynchronously(std::move(top_level_task), UseThreads()); + } + + Status RunVoid(FnOnce(Executor*)> top_level_task) { + return RunSynchronouslyVoid(std::move(top_level_task), UseThreads()); + } +}; + +TEST_P(TestRunSynchronously, SimpleRun) { + bool task_ran = false; + auto task = [&](Executor* executor) { + EXPECT_NE(executor, nullptr); + task_ran = true; + return Future<>::MakeFinished(Status::OK()); + }; + ASSERT_OK(RunVoid(std::move(task))); + EXPECT_TRUE(task_ran); +} + +TEST_P(TestRunSynchronously, SpawnNested) { + bool nested_ran = false; + auto top_level_task = [&](Executor* executor) { + return DeferNotOk(executor->Submit([&] { + nested_ran = true; + return Status::OK(); + })); + }; + ASSERT_OK(RunVoid(std::move(top_level_task))); + EXPECT_TRUE(nested_ran); +} + +TEST_P(TestRunSynchronously, SpawnMoreNested) { + std::atomic nested_ran{0}; + auto top_level_task = [&](Executor* executor) -> Future<> { + auto fut_a = DeferNotOk(executor->Submit([&] { nested_ran++; })); + auto fut_b = DeferNotOk(executor->Submit([&] { nested_ran++; })); + return AllComplete({fut_a, fut_b}) + .Then([&](const Result& result) { + nested_ran++; + return result; + }); + }; + ASSERT_OK(RunVoid(std::move(top_level_task))); + EXPECT_EQ(nested_ran, 3); +} + +TEST_P(TestRunSynchronously, WithResult) { + auto top_level_task = [&](Executor* executor) { + return DeferNotOk(executor->Submit([] { return 42; })); + }; + ASSERT_OK_AND_EQ(42, Run(std::move(top_level_task))); +} + +TEST_P(TestRunSynchronously, StopTokenSpawn) { + bool nested_ran = false; + StopSource stop_source; + auto top_level_task = [&](Executor* executor) -> Future<> { + stop_source.RequestStop(Status::Invalid("XYZ")); + RETURN_NOT_OK(executor->Spawn([&] { nested_ran = true; }, stop_source.token())); + return Future<>::MakeFinished(); + }; + ASSERT_OK(RunVoid(std::move(top_level_task))); + EXPECT_FALSE(nested_ran); +} + +TEST_P(TestRunSynchronously, StopTokenSubmit) { + bool nested_ran = false; + StopSource stop_source; + auto top_level_task = [&](Executor* executor) -> Future<> { + stop_source.RequestStop(); + return DeferNotOk(executor->Submit(stop_source.token(), [&] { + nested_ran = true; + return Status::OK(); + })); + }; + ASSERT_RAISES(Cancelled, RunVoid(std::move(top_level_task))); + EXPECT_FALSE(nested_ran); +} + +TEST_P(TestRunSynchronously, ContinueAfterExternal) { + bool continuation_ran = false; + EXPECT_OK_AND_ASSIGN(auto mock_io_pool, ThreadPool::Make(1)); + auto top_level_task = [&](Executor* executor) { + struct Callback { + Status operator()(...) { + continuation_ran = true; + return Status::OK(); + } + bool& continuation_ran; + }; + return executor + ->Transfer(DeferNotOk(mock_io_pool->Submit([&] { + SleepABit(); + return Status::OK(); + }))) + .Then(Callback{continuation_ran}); + }; + ASSERT_OK(RunVoid(std::move(top_level_task))); + EXPECT_TRUE(continuation_ran); +} + +TEST_P(TestRunSynchronously, SchedulerAbort) { + auto top_level_task = [&](Executor* executor) { return Status::Invalid("XYZ"); }; + ASSERT_RAISES(Invalid, RunVoid(std::move(top_level_task))); +} + +TEST_P(TestRunSynchronously, PropagatedError) { + auto top_level_task = [&](Executor* executor) { + return DeferNotOk(executor->Submit([] { return Status::Invalid("XYZ"); })); + }; + ASSERT_RAISES(Invalid, RunVoid(std::move(top_level_task))); +} + +INSTANTIATE_TEST_SUITE_P(TestRunSynchronously, TestRunSynchronously, + ::testing::Values(false, true)); + class TestThreadPool : public ::testing::Test { public: void TearDown() override {