diff --git a/cpp/src/arrow/testing/async_test_util.h b/cpp/src/arrow/testing/async_test_util.h index b9f5487ed0d..7066bbb63d2 100644 --- a/cpp/src/arrow/testing/async_test_util.h +++ b/cpp/src/arrow/testing/async_test_util.h @@ -20,12 +20,37 @@ #include #include +#include "arrow/testing/gtest_util.h" #include "arrow/util/async_generator.h" #include "arrow/util/future.h" namespace arrow { namespace util { +template +AsyncGenerator AsyncVectorIt(std::vector v) { + return MakeVectorGenerator(std::move(v)); +} + +template +AsyncGenerator FailAt(AsyncGenerator src, int failing_index) { + auto index = std::make_shared>(0); + return [src, index, failing_index]() { + auto idx = index->fetch_add(1); + if (idx >= failing_index) { + return Future::MakeFinished(Status::Invalid("XYZ")); + } + return src(); + }; +} + +template +AsyncGenerator SlowdownABit(AsyncGenerator source) { + return MakeMappedGenerator(std::move(source), [](const T& res) { + return SleepABitAsync().Then([res]() { return res; }); + }); +} + template class TrackingGenerator { public: diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt index 1983819f445..8662c319c99 100644 --- a/cpp/src/arrow/util/CMakeLists.txt +++ b/cpp/src/arrow/util/CMakeLists.txt @@ -79,6 +79,7 @@ add_arrow_test(threading-utility-test counting_semaphore_test.cc future_test.cc task_group_test.cc + test_common.cc thread_pool_test.cc) add_arrow_benchmark(bit_block_counter_benchmark) diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc index 65af7e8eae9..724ad6651eb 100644 --- a/cpp/src/arrow/util/async_generator_test.cc +++ b/cpp/src/arrow/util/async_generator_test.cc @@ -37,30 +37,6 @@ namespace arrow { -template -AsyncGenerator AsyncVectorIt(std::vector v) { - return MakeVectorGenerator(std::move(v)); -} - -template -AsyncGenerator FailsAt(AsyncGenerator src, int failing_index) { - auto index = std::make_shared>(0); - return [src, index, failing_index]() { - auto idx = index->fetch_add(1); - if (idx >= failing_index) { - return Future::MakeFinished(Status::Invalid("XYZ")); - } - return src(); - }; -} - -template -AsyncGenerator SlowdownABit(AsyncGenerator source) { - return MakeMappedGenerator(std::move(source), [](const T& res) { - return SleepABitAsync().Then([res]() { return res; }); - }); -} - template AsyncGenerator MakeJittery(AsyncGenerator source) { auto latency_generator = arrow::io::LatencyGenerator::Make(0.01); @@ -217,9 +193,9 @@ class GeneratorTestFixture : public ::testing::TestWithParam { protected: AsyncGenerator MakeSource(const std::vector& items) { std::vector wrapped(items.begin(), items.end()); - auto gen = AsyncVectorIt(std::move(wrapped)); + auto gen = util::AsyncVectorIt(std::move(wrapped)); if (IsSlow()) { - return SlowdownABit(std::move(gen)); + return util::SlowdownABit(std::move(gen)); } return gen; } @@ -231,7 +207,7 @@ class GeneratorTestFixture : public ::testing::TestWithParam { return Future::MakeFinished(Status::Invalid("XYZ")); }; if (IsSlow()) { - return SlowdownABit(std::move(gen)); + return util::SlowdownABit(std::move(gen)); } return gen; } @@ -324,7 +300,7 @@ class ManualGenerator { }; TEST(TestAsyncUtil, Visit) { - auto generator = AsyncVectorIt({1, 2, 3}); + auto generator = util::AsyncVectorIt({1, 2, 3}); unsigned int sum = 0; auto sum_future = VisitAsyncGenerator(generator, [&sum](TestInt item) { sum += item.value; @@ -336,7 +312,7 @@ TEST(TestAsyncUtil, Visit) { TEST(TestAsyncUtil, Collect) { std::vector expected = {1, 2, 3}; - auto generator = AsyncVectorIt(expected); + auto generator = util::AsyncVectorIt(expected); auto collected = CollectAsyncGenerator(generator); ASSERT_FINISHES_OK_AND_ASSIGN(auto collected_val, collected); ASSERT_EQ(expected, collected_val); @@ -344,7 +320,7 @@ TEST(TestAsyncUtil, Collect) { TEST(TestAsyncUtil, Map) { std::vector input = {1, 2, 3}; - auto generator = AsyncVectorIt(input); + auto generator = util::AsyncVectorIt(input); std::function mapper = [](const TestInt& in) { return std::to_string(in.value); }; @@ -355,7 +331,7 @@ TEST(TestAsyncUtil, Map) { TEST(TestAsyncUtil, MapAsync) { std::vector input = {1, 2, 3}; - auto generator = AsyncVectorIt(input); + auto generator = util::AsyncVectorIt(input); std::function(const TestInt&)> mapper = [](const TestInt& in) { return SleepAsync(1e-3).Then([in]() { return TestStr(std::to_string(in.value)); }); }; @@ -366,7 +342,7 @@ TEST(TestAsyncUtil, MapAsync) { TEST(TestAsyncUtil, MapReentrant) { std::vector input = {1, 2}; - auto source = AsyncVectorIt(input); + auto source = util::AsyncVectorIt(input); util::TrackingGenerator tracker(std::move(source)); source = MakeTransferredGenerator(AsyncGenerator(tracker), internal::GetCpuThreadPool()); @@ -408,7 +384,7 @@ TEST(TestAsyncUtil, MapParallelStress) { constexpr int NITEMS = 10; for (int i = 0; i < NTASKS; i++) { auto gen = MakeVectorGenerator(RangeVector(NITEMS)); - gen = SlowdownABit(std::move(gen)); + gen = util::SlowdownABit(std::move(gen)); auto guard = ExpectNotAccessedReentrantly(&gen); std::function mapper = [](const TestInt& in) { SleepABit(); @@ -427,9 +403,9 @@ TEST(TestAsyncUtil, MapQueuingFailStress) { for (bool slow : {true, false}) { for (int i = 0; i < NTASKS; i++) { std::shared_ptr> done = std::make_shared>(); - auto inner = AsyncVectorIt(RangeVector(NITEMS)); + auto inner = util::AsyncVectorIt(RangeVector(NITEMS)); if (slow) inner = MakeJittery(inner); - auto gen = FailsAt(inner, NITEMS / 2); + auto gen = util::FailAt(inner, NITEMS / 2); std::function mapper = [done](const TestInt& in) { if (done->load()) { ADD_FAILURE() << "Callback called after generator sent end signal"; @@ -446,7 +422,7 @@ TEST(TestAsyncUtil, MapQueuingFailStress) { TEST(TestAsyncUtil, MapTaskFail) { std::vector input = {1, 2, 3}; - auto generator = AsyncVectorIt(input); + auto generator = util::AsyncVectorIt(input); std::function(const TestInt&)> mapper = [](const TestInt& in) -> Result { if (in.value == 2) { @@ -492,7 +468,7 @@ TEST(TestAsyncUtil, MapTaskDelayedFail) { TEST(TestAsyncUtil, MapSourceFail) { std::vector input = {1, 2, 3}; - auto generator = FailsAt(AsyncVectorIt(input), 1); + auto generator = util::FailAt(util::AsyncVectorIt(input), 1); std::function(const TestInt&)> mapper = [](const TestInt& in) -> Result { return TestStr(std::to_string(in.value)); @@ -505,8 +481,8 @@ TEST(TestAsyncUtil, Concatenated) { std::vector inputOne{1, 2, 3}; std::vector inputTwo{4, 5, 6}; std::vector expected{1, 2, 3, 4, 5, 6}; - auto gen = AsyncVectorIt>( - {AsyncVectorIt(inputOne), AsyncVectorIt(inputTwo)}); + auto gen = util::AsyncVectorIt>( + {util::AsyncVectorIt(inputOne), util::AsyncVectorIt(inputTwo)}); auto concat = MakeConcatenatedGenerator(gen); AssertAsyncGeneratorMatch(expected, concat); } @@ -523,7 +499,7 @@ TEST_P(FromFutureFixture, Basic) { auto to_gen = source.Then([slow](const std::vector& vec) { auto vec_gen = MakeVectorGenerator(vec); if (slow) { - return SlowdownABit(std::move(vec_gen)); + return util::SlowdownABit(std::move(vec_gen)); } return vec_gen; }); @@ -538,7 +514,7 @@ INSTANTIATE_TEST_SUITE_P(FromFutureTests, FromFutureFixture, class MergedGeneratorTestFixture : public GeneratorTestFixture {}; TEST_P(MergedGeneratorTestFixture, Merged) { - auto gen = AsyncVectorIt>( + auto gen = util::AsyncVectorIt>( {MakeSource({1, 2, 3}), MakeSource({4, 5, 6})}); auto concat_gen = MakeMergedGenerator(gen, 10); @@ -552,9 +528,9 @@ TEST_P(MergedGeneratorTestFixture, Merged) { } TEST_P(MergedGeneratorTestFixture, OuterSubscriptionEmpty) { - auto gen = AsyncVectorIt>({}); + auto gen = util::AsyncVectorIt>({}); if (IsSlow()) { - gen = SlowdownABit(gen); + gen = util::SlowdownABit(gen); } auto merged_gen = MakeMergedGenerator(gen, 10); ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, @@ -563,8 +539,9 @@ TEST_P(MergedGeneratorTestFixture, OuterSubscriptionEmpty) { } TEST_P(MergedGeneratorTestFixture, MergedInnerFail) { - auto gen = AsyncVectorIt>( - {MakeSource({1, 2, 3}), FailsAt(MakeSource({1, 2, 3}), 1), MakeSource({1, 2, 3})}); + auto gen = util::AsyncVectorIt>( + {MakeSource({1, 2, 3}), util::FailAt(MakeSource({1, 2, 3}), 1), + MakeSource({1, 2, 3})}); auto merged_gen = MakeMergedGenerator(gen, 10); // Merged generator can be pulled async-reentrantly and we need to make // sure, if it is, that all futures are marked complete, even if there is an error @@ -672,23 +649,23 @@ TEST_P(MergedGeneratorTestFixture, MergedInnerFailCleanup) { TEST_P(MergedGeneratorTestFixture, FinishesQuickly) { // Testing a source that finishes on the first pull - auto source = AsyncVectorIt>({MakeSource({1})}); + auto source = util::AsyncVectorIt>({MakeSource({1})}); auto merged = MakeMergedGenerator(std::move(source), 10); ASSERT_FINISHES_OK_AND_EQ(TestInt(1), merged()); AssertGeneratorExhausted(merged); } TEST_P(MergedGeneratorTestFixture, MergedOuterFail) { - auto gen = - FailsAt(AsyncVectorIt>( - {MakeSource({1, 2, 3}), MakeSource({1, 2, 3}), MakeSource({1, 2, 3})}), - 1); + auto gen = util::FailAt( + util::AsyncVectorIt>( + {MakeSource({1, 2, 3}), MakeSource({1, 2, 3}), MakeSource({1, 2, 3})}), + 1); auto merged_gen = MakeMergedGenerator(gen, 10); ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(merged_gen)); } TEST_P(MergedGeneratorTestFixture, MergedLimitedSubscriptions) { - auto gen = AsyncVectorIt>( + auto gen = util::AsyncVectorIt>( {MakeSource({1, 2}), MakeSource({3, 4}), MakeSource({5, 6, 7, 8}), MakeSource({9, 10, 11, 12})}); util::TrackingGenerator> tracker(std::move(gen)); @@ -739,7 +716,7 @@ TEST_P(MergedGeneratorTestFixture, MergedStress) { guards.push_back(ExpectNotAccessedReentrantly(&source)); sources.push_back(source); } - AsyncGenerator> source_gen = AsyncVectorIt(sources); + AsyncGenerator> source_gen = util::AsyncVectorIt(sources); auto outer_gaurd = ExpectNotAccessedReentrantly(&source_gen); auto merged = MakeMergedGenerator(source_gen, 4); @@ -756,7 +733,7 @@ TEST_P(MergedGeneratorTestFixture, MergedParallelStress) { for (int j = 0; j < NGENERATORS; j++) { sources.push_back(MakeSource(RangeVector(NITEMS))); } - auto merged = MakeMergedGenerator(AsyncVectorIt(sources), 4); + auto merged = MakeMergedGenerator(util::AsyncVectorIt(sources), 4); merged = MakeReadaheadGenerator(merged, 4); ASSERT_FINISHES_OK_AND_ASSIGN(auto items, CollectAsyncGenerator(merged)); ASSERT_EQ(NITEMS * NGENERATORS, items.size()); @@ -1448,7 +1425,7 @@ TEST(TestAsyncUtil, ReadaheadOneItem) { } TEST(TestAsyncUtil, ReadaheadCopy) { - auto source = AsyncVectorIt(RangeVector(6)); + auto source = util::AsyncVectorIt(RangeVector(6)); auto gen = MakeReadaheadGenerator(std::move(source), 2); for (int i = 0; i < 2; i++) { @@ -1466,7 +1443,7 @@ TEST(TestAsyncUtil, ReadaheadCopy) { } TEST(TestAsyncUtil, ReadaheadMove) { - auto source = AsyncVectorIt(RangeVector(6)); + auto source = util::AsyncVectorIt(RangeVector(6)); auto gen = MakeReadaheadGenerator(std::move(source), 2); for (int i = 0; i < 2; i++) { @@ -1600,7 +1577,7 @@ TEST_P(EnumeratorTestFixture, Empty) { } TEST_P(EnumeratorTestFixture, Error) { - auto source = FailsAt(MakeSource({1, 2, 3}), 1); + auto source = util::FailAt(MakeSource({1, 2, 3}), 1); auto enumerated = MakeEnumeratedGenerator(std::move(source)); // Even though the first item finishes ok the enumerator buffers it. The error then @@ -1666,7 +1643,7 @@ class PauseableTestFixture : public GeneratorTestFixture { AsyncGenerator GetSource() { const auto& source = static_cast>(generator_); if (IsSlow()) { - return SlowdownABit(source); + return util::SlowdownABit(source); } else { return source; } @@ -1741,7 +1718,7 @@ TEST_P(SequencerTestFixture, SequenceLambda) { TEST_P(SequencerTestFixture, SequenceError) { { auto original = MakeSource({6, 4, 2}); - original = FailsAt(original, 1); + original = util::FailAt(original, 1); auto sequenced = MakeSequencingGenerator(original, cmp_, is_next_, TestInt(0)); auto collected = CollectAsyncGenerator(sequenced); ASSERT_FINISHES_AND_RAISES(Invalid, collected); @@ -1824,7 +1801,7 @@ INSTANTIATE_TEST_SUITE_P(SequencerTests, SequencerTestFixture, ::testing::Values(false, true)); TEST(TestAsyncIteratorTransform, SkipSome) { - auto original = AsyncVectorIt({1, 2, 3}); + auto original = util::AsyncVectorIt({1, 2, 3}); auto filter = MakeFilter([](TestInt& t) { return t.value != 2; }); auto filtered = MakeTransformedGenerator(std::move(original), filter); AssertAsyncGeneratorMatch({"1", "3"}, std::move(filtered)); diff --git a/cpp/src/arrow/util/thread_pool.cc b/cpp/src/arrow/util/thread_pool.cc index a1387947e3a..e41a8f00e65 100644 --- a/cpp/src/arrow/util/thread_pool.cc +++ b/cpp/src/arrow/util/thread_pool.cc @@ -49,12 +49,24 @@ struct SerialExecutor::State { std::deque task_queue; std::mutex mutex; std::condition_variable wait_for_tasks; + bool paused{false}; bool finished{false}; }; SerialExecutor::SerialExecutor() : state_(std::make_shared()) {} -SerialExecutor::~SerialExecutor() = default; +SerialExecutor::~SerialExecutor() { + auto state = state_; + std::unique_lock lk(state->mutex); + if (!state->task_queue.empty()) { + // We may have remaining tasks if the executor is being abandoned. We could have + // resource leakage in this case. However, we can force the cleanup to happen now + state->paused = false; + lk.unlock(); + RunLoop(); + lk.lock(); + } +} Status SerialExecutor::SpawnReal(TaskHints hints, FnOnce task, StopToken stop_token, StopCallback&& stop_callback) { @@ -68,6 +80,11 @@ Status SerialExecutor::SpawnReal(TaskHints hints, FnOnce task, auto state = state_; { std::lock_guard lk(state->mutex); + if (state_->finished) { + return Status::Invalid( + "Attempt to schedule a task on a serial executor that has already finished or " + "been abandoned"); + } state->task_queue.push_back( Task{std::move(task), std::move(stop_token), std::move(stop_callback)}); } @@ -75,8 +92,17 @@ Status SerialExecutor::SpawnReal(TaskHints hints, FnOnce task, return Status::OK(); } -void SerialExecutor::MarkFinished() { +void SerialExecutor::Pause() { // Same comment as SpawnReal above + auto state = state_; + { + std::lock_guard lk(state->mutex); + state->paused = true; + } + state->wait_for_tasks.notify_one(); +} + +void SerialExecutor::Finish() { auto state = state_; { std::lock_guard lk(state->mutex); @@ -85,13 +111,32 @@ void SerialExecutor::MarkFinished() { state->wait_for_tasks.notify_one(); } +bool SerialExecutor::IsFinished() { + std::lock_guard lk(state_->mutex); + return state_->finished; +} + +void SerialExecutor::Unpause() { + auto state = state_; + { + std::lock_guard lk(state->mutex); + state->paused = false; + } +} + void SerialExecutor::RunLoop() { // This is called from the SerialExecutor's main thread, so the // state is guaranteed to be kept alive. std::unique_lock lk(state_->mutex); - while (!state_->finished) { - while (!state_->task_queue.empty()) { + // If paused we break out immediately. If finished we only break out + // when all work is done. + while (!state_->paused && !(state_->finished && state_->task_queue.empty())) { + // The inner loop is to check if we need to sleep (e.g. while waiting on some + // async task to finish from another thread pool). We still need to check paused + // because sometimes we will pause even with work leftover when processing + // an async generator + while (!state_->paused && !state_->task_queue.empty()) { Task task = std::move(state_->task_queue.front()); state_->task_queue.pop_front(); lk.unlock(); @@ -108,8 +153,9 @@ void SerialExecutor::RunLoop() { } // In this case we must be waiting on work from external (e.g. I/O) executors. Wait // for tasks to arrive (typically via transferred futures). - state_->wait_for_tasks.wait( - lk, [&] { return state_->finished || !state_->task_queue.empty(); }); + state_->wait_for_tasks.wait(lk, [&] { + return state_->paused || state_->finished || !state_->task_queue.empty(); + }); } } diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h index a104e0e3590..4b7002a6736 100644 --- a/cpp/src/arrow/util/thread_pool.h +++ b/cpp/src/arrow/util/thread_pool.h @@ -36,6 +36,7 @@ #include "arrow/util/cancel.h" #include "arrow/util/functional.h" #include "arrow/util/future.h" +#include "arrow/util/iterator.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" @@ -276,6 +277,82 @@ class ARROW_EXPORT SerialExecutor : public Executor { return FutureToSync(fut); } + /// \brief Transform an AsyncGenerator into an Iterator + /// + /// An event loop will be created and each call to Next will power the event loop with + /// the calling thread until the next item is ready to be delivered. + /// + /// Note: The iterator's destructor will run until the given generator is fully + /// exhausted. If you wish to abandon iteration before completion then the correct + /// approach is to use a stop token to cause the generator to exhaust early. + template + static Iterator IterateGenerator( + internal::FnOnce()>>(Executor*)> initial_task) { + auto serial_executor = std::unique_ptr(new SerialExecutor()); + auto maybe_generator = std::move(initial_task)(serial_executor.get()); + if (!maybe_generator.ok()) { + return MakeErrorIterator(maybe_generator.status()); + } + auto generator = maybe_generator.MoveValueUnsafe(); + struct SerialIterator { + SerialIterator(std::unique_ptr executor, + std::function()> generator) + : executor(std::move(executor)), generator(std::move(generator)) {} + ARROW_DISALLOW_COPY_AND_ASSIGN(SerialIterator); + ARROW_DEFAULT_MOVE_AND_ASSIGN(SerialIterator); + ~SerialIterator() { + // A serial iterator must be consumed before it can be destroyed. Allowing it to + // do otherwise would lead to resource leakage. There will likely be deadlocks at + // this spot in the future but these will be the result of other bugs and not the + // fact that we are forcing consumption here. + + // If a streaming API needs to support early abandonment then it should be done so + // with a cancellation token and not simply discarding the iterator and expecting + // the underlying work to clean up correctly. + if (executor && !executor->IsFinished()) { + while (true) { + Result maybe_next = Next(); + if (!maybe_next.ok() || IsIterationEnd(*maybe_next)) { + break; + } + } + } + } + + Result Next() { + executor->Unpause(); + // This call may lead to tasks being scheduled in the serial executor + Future next_fut = generator(); + next_fut.AddCallback([this](const Result& res) { + // If we're done iterating we should drain the rest of the tasks in the executor + if (!res.ok() || IsIterationEnd(*res)) { + executor->Finish(); + return; + } + // Otherwise we will break out immediately, leaving the remaining tasks for + // the next call. + executor->Pause(); + }); + // Borrow this thread and run tasks until the future is finished + executor->RunLoop(); + if (!next_fut.is_finished()) { + // Not clear this is possible since RunLoop wouldn't generally exit + // unless we paused/finished which would imply next_fut has been + // finished. + return Status::Invalid( + "Serial executor terminated before next result computed"); + } + // At this point we may still have tasks in the executor, that is ok. + // We will run those tasks the next time through. + return next_fut.result(); + } + + std::unique_ptr executor; + std::function()> generator; + }; + return Iterator(SerialIterator{std::move(serial_executor), std::move(generator)}); + } + private: SerialExecutor(); @@ -283,18 +360,25 @@ class ARROW_EXPORT SerialExecutor : public Executor { struct State; std::shared_ptr state_; + void RunLoop(); + // We mark the serial executor "finished" when there should be + // no more tasks scheduled on it. It's not strictly needed but + // can help catch bugs where we are trying to use the executor + // after we are done with it. + void Finish(); + bool IsFinished(); + // We pause the executor when we are running an async generator + // and we have received an item that we can deliver. + void Pause(); + void Unpause(); + template ::SyncType> Future Run(TopLevelTask initial_task) { auto final_fut = std::move(initial_task)(this); - if (final_fut.is_finished()) { - return final_fut; - } - final_fut.AddCallback([this](const FTSync&) { MarkFinished(); }); + final_fut.AddCallback([this](const FTSync&) { Finish(); }); RunLoop(); return final_fut; } - void RunLoop(); - void MarkFinished(); }; /// An Executor implementation spawning tasks in FIFO manner on a fixed-size diff --git a/cpp/src/arrow/util/thread_pool_test.cc b/cpp/src/arrow/util/thread_pool_test.cc index 40a544c8829..047f3634769 100644 --- a/cpp/src/arrow/util/thread_pool_test.cc +++ b/cpp/src/arrow/util/thread_pool_test.cc @@ -33,6 +33,7 @@ #include #include "arrow/status.h" +#include "arrow/testing/async_test_util.h" #include "arrow/testing/executor_util.h" #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" @@ -261,6 +262,141 @@ TEST_P(TestRunSynchronously, PropagatedError) { INSTANTIATE_TEST_SUITE_P(TestRunSynchronously, TestRunSynchronously, ::testing::Values(false, true)); +TEST(SerialExecutor, AsyncGenerator) { + std::vector values{1, 2, 3, 4, 5}; + auto source = util::SlowdownABit(util::AsyncVectorIt(values)); + Iterator iter = + SerialExecutor::IterateGenerator([&source](Executor* executor) { + return MakeMappedGenerator(source, [executor](const TestInt& ti) { + return DeferNotOk(executor->Submit([ti] { return ti; })); + }); + }); + ASSERT_OK_AND_ASSIGN(auto vec, iter.ToVector()); + ASSERT_EQ(vec, values); +} + +TEST(SerialExecutor, AsyncGeneratorWithFollowUp) { + // Sometimes a task will generate follow-up tasks. These should be run + // before the next task is started + bool follow_up_ran = false; + bool first = true; + Iterator iter = + SerialExecutor::IterateGenerator([&](Executor* executor) { + return [=, &first, &follow_up_ran]() -> Future { + if (first) { + first = false; + Future item = + DeferNotOk(executor->Submit([] { return TestInt(0); })); + RETURN_NOT_OK(executor->Spawn([&] { follow_up_ran = true; })); + return item; + } + return DeferNotOk(executor->Submit([] { return IterationEnd(); })); + }; + }); + ASSERT_FALSE(follow_up_ran); + ASSERT_OK_AND_EQ(TestInt(0), iter.Next()); + ASSERT_FALSE(follow_up_ran); + ASSERT_OK_AND_EQ(IterationEnd(), iter.Next()); + ASSERT_TRUE(follow_up_ran); +} + +TEST(SerialExecutor, AsyncGeneratorWithAsyncFollowUp) { + // Simulates a situation where a user calls into the async generator, tasks (e.g. I/O + // readahead tasks) are spawned onto the I/O threadpool, the user gets a result, and + // then the I/O readahead tasks are completed while there is no calling thread in the + // async generator to hand the task off to (it should be queued up) + bool follow_up_ran = false; + bool first = true; + Executor* captured_executor = nullptr; + Iterator iter = + SerialExecutor::IterateGenerator([&](Executor* executor) { + return [=, &first, &captured_executor]() -> Future { + if (first) { + captured_executor = executor; + first = false; + return DeferNotOk(executor->Submit([] { + // I/O tasks would be scheduled at this point + return TestInt(0); + })); + } + return DeferNotOk(executor->Submit([] { return IterationEnd(); })); + }; + }); + ASSERT_FALSE(follow_up_ran); + ASSERT_OK_AND_EQ(TestInt(0), iter.Next()); + // I/O task completes and has reference to executor to submit continuation + ASSERT_OK(captured_executor->Spawn([&] { follow_up_ran = true; })); + // Follow-up task can't run right now because there is no thread in the executor + SleepABit(); + ASSERT_FALSE(follow_up_ran); + // Follow-up should run as part of retrieving the next item + ASSERT_OK_AND_EQ(IterationEnd(), iter.Next()); + ASSERT_TRUE(follow_up_ran); +} + +TEST(SerialExecutor, AsyncGeneratorWithCleanup) { + // Test the case where tasks are added to the executor after the task that + // marks the final future complete (i.e. the terminal item). These tasks + // must run before the terminal item is delivered from the iterator. + bool follow_up_ran = false; + Iterator iter = + SerialExecutor::IterateGenerator([&](Executor* executor) { + return [=, &follow_up_ran]() -> Future { + Future end = + DeferNotOk(executor->Submit([] { return IterationEnd(); })); + RETURN_NOT_OK(executor->Spawn([&] { follow_up_ran = true; })); + return end; + }; + }); + ASSERT_FALSE(follow_up_ran); + ASSERT_OK_AND_EQ(IterationEnd(), iter.Next()); + ASSERT_TRUE(follow_up_ran); +} + +TEST(SerialExecutor, AbandonIteratorWithCleanup) { + // If we abandon an iterator we still need to drain all remaining tasks + bool follow_up_ran = false; + bool first = true; + { + Iterator iter = + SerialExecutor::IterateGenerator([&](Executor* executor) { + return [=, &first, &follow_up_ran]() -> Future { + if (first) { + first = false; + Future item = + DeferNotOk(executor->Submit([] { return TestInt(0); })); + RETURN_NOT_OK(executor->Spawn([&] { follow_up_ran = true; })); + return item; + } + return DeferNotOk(executor->Submit([] { return IterationEnd(); })); + }; + }); + ASSERT_FALSE(follow_up_ran); + ASSERT_OK_AND_EQ(TestInt(0), iter.Next()); + // At this point the iterator still has one remaining cleanup task + ASSERT_FALSE(follow_up_ran); + } + ASSERT_TRUE(follow_up_ran); +} + +TEST(SerialExecutor, FailingIteratorWithCleanup) { + // If an iterator hits an error we should still generally run any remaining tasks as + // they might be cleanup tasks. + bool follow_up_ran = false; + Iterator iter = + SerialExecutor::IterateGenerator([&](Executor* executor) { + return [=, &follow_up_ran]() -> Future { + Future end = DeferNotOk(executor->Submit( + []() -> Result { return Status::Invalid("XYZ"); })); + RETURN_NOT_OK(executor->Spawn([&] { follow_up_ran = true; })); + return end; + }; + }); + ASSERT_FALSE(follow_up_ran); + ASSERT_RAISES(Invalid, iter.Next()); + ASSERT_TRUE(follow_up_ran); +} + class TransferTest : public testing::Test { public: internal::Executor* executor() { return mock_executor.get(); }