diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index d0e1be92fc9..a4ba476a11c 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -959,8 +959,10 @@ class GatingTask::Impl : public std::enable_shared_from_this<GatingTask::Impl> {
   }
 
   Future<> AsyncTask() {
+    std::lock_guard<std::mutex> lk(mx_);
     num_launched_++;
     num_running_++;
+    running_cv_.notify_all();
     /// TODO(ARROW-13004) Could maybe implement this check with future chains
     /// if we check to see if the future has been "consumed" or not
     num_finished_++;
diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h
index 0948e5537fe..d132198a259 100644
--- a/cpp/src/arrow/util/async_generator.h
+++ b/cpp/src/arrow/util/async_generator.h
@@ -60,6 +60,12 @@ namespace arrow {
 // Readahead operators, and some other operators, may introduce queueing. Any operators
 // that introduce buffering should detail the amount of buffering they introduce in their
 // MakeXYZ function comments.
+//
+// A generator should always be fully consumed before it is destroyed.
+// A generator should not mark a future complete with an error status or a terminal value
+// until all outstanding futures have completed. Generators that spawn multiple
+// concurrent futures may need to hold onto an error while other concurrent futures wrap
+// up.
 
 template <typename T>
 using AsyncGenerator = std::function<Future<T>()>;
@@ -750,19 +756,32 @@ class ReadaheadGenerator {
   Future<T> AddMarkFinishedContinuation(Future<T> fut) {
     auto state = state_;
     return fut.Then(
-        [state](const T& result) -> Result<T> {
+        [state](const T& result) -> Future<T> {
           state->MarkFinishedIfDone(result);
+          if (state->finished.load()) {
+            if (state->num_running.fetch_sub(1) == 1) {
+              state->final_future.MarkFinished();
+            }
+          } else {
+            state->num_running.fetch_sub(1);
+          }
           return result;
         },
-        [state](const Status& err) -> Result<T> {
+        [state](const Status& err) -> Future<T> {
+          // If there is an error we need to make sure all running
+          // tasks finish before we return the error.
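+          // The last task to decrement num_running sees fetch_sub() return 1
+          // and marks final_future, which releases the error held below.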
           state->finished.store(true);
-          return err;
+          if (state->num_running.fetch_sub(1) == 1) {
+            state->final_future.MarkFinished();
+          }
+          return state->final_future.Then([err]() -> Result<T> { return err; });
         });
   }
 
   Future<T> operator()() {
     if (state_->readahead_queue.empty()) {
       // This is the first request, let's pump the underlying queue
+      state_->num_running.store(state_->max_readahead);
       for (int i = 0; i < state_->max_readahead; i++) {
         auto next = state_->source_generator();
         auto next_after_check = AddMarkFinishedContinuation(std::move(next));
@@ -775,6 +794,7 @@ class ReadaheadGenerator {
     if (state_->finished.load()) {
       state_->readahead_queue.push(AsyncGeneratorEnd<T>());
     } else {
+      state_->num_running.fetch_add(1);
       auto back_of_queue = state_->source_generator();
       auto back_of_queue_after_check =
           AddMarkFinishedContinuation(std::move(back_of_queue));
@@ -786,9 +806,7 @@
  private:
   struct State {
     State(AsyncGenerator<T> source_generator, int max_readahead)
-        : source_generator(std::move(source_generator)), max_readahead(max_readahead) {
-      finished.store(false);
-    }
+        : source_generator(std::move(source_generator)), max_readahead(max_readahead) {}
 
     void MarkFinishedIfDone(const T& next_result) {
       if (IsIterationEnd(next_result)) {
@@ -798,7 +816,9 @@
 
     AsyncGenerator<T> source_generator;
     int max_readahead;
-    std::atomic<bool> finished;
+    Future<> final_future = Future<>::Make();
+    std::atomic<int> num_running{0};
+    std::atomic<bool> finished{false};
     std::queue<Future<T>> readahead_queue;
   };
 
@@ -990,39 +1010,140 @@ AsyncGenerator<T> MakeVectorGenerator(std::vector<T> vec) {
 
 /// \see MakeMergedGenerator
 template <typename T>
 class MergedGenerator {
+  // Note, the implementation of this class is quite complex at the moment (PRs to
+  // simplify are always welcome)
+  //
+  // Terminology is borrowed from rxjs. This is a pull-based implementation of the
+  // mergeAll operator. The "outer subscription" refers to the async
+  // generator that the caller provided when creating this. The outer subscription
+  // yields generators.
+  //
+  // Each of these generators is then subscribed to (up to max_subscriptions) and these
+  // are referred to as "inner subscriptions".
+  //
+  // As soon as we start we try and establish `max_subscriptions` inner subscriptions.
+  // For each inner subscription we will cache up to 1 value. This means we may have
+  // more values than we have been asked for. In the scanner example below, if a caller
+  // asks for one record batch we will start scanning `max_subscriptions` different
+  // files. For each file we will only queue up to 1 batch (so a separate readahead is
+  // needed on the file if batch readahead is desired).
+  //
+  // If the caller is slow we may accumulate ready-to-deliver items. These are stored
+  // in `delivered_jobs`.
+  //
+  // If the caller is very quick we may accumulate requests. These are stored in
+  // `waiting_jobs`.
+  //
+  // It may be helpful to consider an example: in the scanner, the outer subscription
+  // is some kind of asynchronous directory listing. The inner subscription is
+  // then a scan on a file yielded by the directory listing.
+  //
+  // An "outstanding" request is when we have polled either the inner or outer
+  // subscription but that future hasn't completed yet.
+  //
+  // There are three possible "events" that can happen.
+  // * A caller could request the next future
+  // * An outer callback occurs when the next subscription is ready (e.g. the
+  //   directory listing has produced a new file)
+  // * An inner callback occurs when one of the inner subscriptions emits a value (e.g.
+  //   a file scan emits a record batch)
+  //
+  // Any time an event happens the logic is broken into two phases. First, we grab the
+  // lock and modify the shared state. While doing this we figure out what callbacks we
+  // will need to execute. Then, we give up the lock and execute these callbacks. It is
+  // important to execute these callbacks without the lock to avoid deadlock.
  public:
   explicit MergedGenerator(AsyncGenerator<AsyncGenerator<T>> source,
                            int max_subscriptions)
       : state_(std::make_shared<State>(std::move(source), max_subscriptions)) {}
 
   Future<T> operator()() {
+    // A caller has requested a future
     Future<T> waiting_future;
     std::shared_ptr<DeliveredJob> delivered_job;
+    bool mark_generator_complete = false;
     {
       auto guard = state_->mutex.Lock();
       if (!state_->delivered_jobs.empty()) {
+        // If we have a job sitting around we can deliver it
         delivered_job = std::move(state_->delivered_jobs.front());
         state_->delivered_jobs.pop_front();
-      } else if (state_->finished) {
-        return IterationTraits<T>::End();
+        if (state_->IsCompleteUnlocked(guard)) {
+          // It's possible this waiting job was the only thing left to handle and
+          // we have now completed the generator.
+          mark_generator_complete = true;
+        } else {
+          // Since we had a job sitting around we also had an inner subscription
+          // that had paused. We are going to restart this inner subscription and
+          // so there will be a new outstanding request.
+          state_->outstanding_requests++;
+        }
+      } else if (state_->broken ||
+                 (!state_->first && state_->num_running_subscriptions == 0)) {
+        // If we are broken or exhausted then prepare a terminal item but
+        // we won't complete it until we've finished.
+        Result<T> end_res = IterationEnd<T>();
+        if (!state_->final_error.ok()) {
+          end_res = state_->final_error;
+          state_->final_error = Status::OK();
+        }
+        return state_->all_finished.Then([end_res]() -> Result<T> { return end_res; });
       } else {
+        // Otherwise we just queue the request and it will be completed when one of the
+        // ongoing inner subscriptions delivers a result
        waiting_future = Future<T>::Make();
        state_->waiting_jobs.push_back(std::make_shared<Future<T>>(waiting_future));
       }
+      if (state_->first) {
+        // On the first request we are going to try and immediately fill our queue
+        // of subscriptions. We assume we are going to be able to start them all.
+        state_->outstanding_requests +=
+            static_cast<int>(state_->active_subscriptions.size());
+        state_->num_running_subscriptions +=
+            static_cast<int>(state_->active_subscriptions.size());
+      }
     }
+    // If we grabbed a finished item from the delivered_jobs queue then we may need
+    // to mark the generator finished or issue a request for a new item to fill in
+    // the spot we just vacated. Notice that we issue that request to the same
+    // subscription that delivered it (deliverer).
     if (delivered_job) {
-      // deliverer will be invalid if outer callback encounters an error and delivers a
-      // failed result
-      if (delivered_job->deliverer) {
+      if (mark_generator_complete) {
+        state_->all_finished.MarkFinished();
+      } else {
         delivered_job->deliverer().AddCallback(
             InnerCallback{state_, delivered_job->index});
       }
       return std::move(delivered_job->value);
     }
+    // On the first call we try and fill up our subscriptions. It's possible the outer
+    // generator only has a few items and we can't fill up to what we were hoping. In
+    // that case we have to bail early.
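+    // Note that PullSource() may complete synchronously, in which case the
+    // OuterCallback below runs before the loop's next iteration and may set
+    // source_exhausted; that is why the loop re-checks it under the lock.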
     if (state_->first) {
       state_->first = false;
-      for (std::size_t i = 0; i < state_->active_subscriptions.size(); i++) {
-        state_->PullSource().AddCallback(OuterCallback{state_, i});
+      mark_generator_complete = false;
+      for (int i = 0; i < static_cast<int>(state_->active_subscriptions.size()); i++) {
+        state_->PullSource().AddCallback(
+            OuterCallback{state_, static_cast<std::size_t>(i)});
+        // If we have to bail early then we need to update the shared state again so
+        // we need to reacquire the lock.
+        auto guard = state_->mutex.Lock();
+        if (state_->source_exhausted) {
+          int excess_requests =
+              static_cast<int>(state_->active_subscriptions.size()) - i - 1;
+          state_->outstanding_requests -= excess_requests;
+          state_->num_running_subscriptions -= excess_requests;
+          if (excess_requests > 0) {
+            // It's possible that we are completing the generator by reducing the number
+            // of outstanding requests (e.g. this happens when the outer subscription and
+            // all inner subscriptions are synchronous)
+            mark_generator_complete = state_->IsCompleteUnlocked(guard);
+          }
+          break;
+        }
+      }
+      if (mark_generator_complete) {
+        state_->MarkFinishedAndPurge();
       }
     }
     return waiting_future;
@@ -1034,8 +1155,13 @@
                  std::size_t index_)
         : deliverer(deliverer_), value(std::move(value_)), index(index_) {}
 
+    // The generator that delivered this result; we will request another item
+    // from this generator once the result is delivered
     AsyncGenerator<T> deliverer;
+    // The result we received from the generator
     Result<T> value;
+    // The index of the generator (in active_subscriptions) that delivered this
+    // result. This is used if we need to replace a finished generator.
     std::size_t index;
   };
 
@@ -1047,9 +1173,11 @@
           waiting_jobs(),
           mutex(),
           first(true),
+          broken(false),
           source_exhausted(false),
-          finished(false),
-          num_active_subscriptions(max_subscriptions) {}
+          outstanding_requests(0),
+          num_running_subscriptions(0),
+          final_error(Status::OK()) {}
 
     Future<AsyncGenerator<T>> PullSource() {
       // Need to guard access to source() so we don't pull sync-reentrantly which
@@ -1058,50 +1186,178 @@
       return source();
     }
 
+    void SignalErrorUnlocked(const util::Mutex::Guard& guard) {
+      broken = true;
+      // Empty any results that have arrived but not been asked for.
+      while (!delivered_jobs.empty()) {
+        delivered_jobs.pop_front();
+      }
+    }
+
+    // This function is called outside the mutex but it will only ever be
+    // called once
+    void MarkFinishedAndPurge() {
+      all_finished.MarkFinished();
+      while (!waiting_jobs.empty()) {
+        waiting_jobs.front()->MarkFinished(IterationEnd<T>());
+        waiting_jobs.pop_front();
+      }
+    }
+
+    // This is called outside the mutex but it is only ever called
+    // once and Future<>::AddCallback is thread-safe
+    void MarkFinalError(const Status& err, Future<T> maybe_sink) {
+      if (maybe_sink.is_valid()) {
+        // Someone is waiting for this error so let's mark it complete when
+        // all the work is done
+        all_finished.AddCallback([maybe_sink, err](const Status& status) mutable {
+          maybe_sink.MarkFinished(err);
+        });
+      } else {
+        // No one is waiting for this error right now so it will be delivered
+        // next.
+        final_error = err;
+      }
+    }
+
+    bool IsCompleteUnlocked(const util::Mutex::Guard& guard) {
+      return outstanding_requests == 0 &&
+             (broken || (source_exhausted && num_running_subscriptions == 0 &&
+                         delivered_jobs.empty()));
+    }
+
+    bool MarkTaskFinishedUnlocked(const util::Mutex::Guard& guard) {
+      --outstanding_requests;
+      return IsCompleteUnlocked(guard);
+    }
+
+    // The outer generator. Each item we pull from this will be its own generator
+    // and become an inner subscription
     AsyncGenerator<AsyncGenerator<T>> source;
     // active_subscriptions and delivered_jobs will be bounded by max_subscriptions
     std::vector<AsyncGenerator<T>> active_subscriptions;
+    // Results delivered by the inner subscriptions that weren't yet asked for by the
+    // caller
     std::deque<std::shared_ptr<DeliveredJob>> delivered_jobs;
     // waiting_jobs is unbounded, reentrant pulls (e.g. AddReadahead) will provide the
     // backpressure
     std::deque<std::shared_ptr<Future<T>>> waiting_jobs;
+    // A future that will be marked complete when the terminal item has arrived and all
+    // outstanding futures have completed. It is used to hold off emission of an error
+    // until all outstanding work is done.
+    Future<> all_finished = Future<>::Make();
     util::Mutex mutex;
+    // A flag cleared when the caller first asks for a future. Used to start polling.
     bool first;
+    // A flag set when an error arrives; it prevents us from issuing new requests.
+    bool broken;
+    // A flag set when the outer subscription has been exhausted. Prevents us from
+    // pulling it further (even though it would be generally harmless) and lets us know
+    // we are finishing up.
     bool source_exhausted;
-    bool finished;
-    int num_active_subscriptions;
+    // The number of futures that we have requested from either the outer or inner
+    // subscriptions that have not yet completed. We cannot mark all_finished until
+    // this reaches 0. This will never be greater than max_subscriptions.
+    int outstanding_requests;
+    // The number of running subscriptions. We ramp this up to `max_subscriptions` as
+    // soon as the first item is requested and then it stays at that level (each
+    // exhausted inner subscription is replaced by a new inner subscription) until the
+    // outer subscription is exhausted, at which point this descends to 0 (and
+    // source_exhausted is then set to true).
+    int num_running_subscriptions;
+    // If an error arrives, and the caller hasn't asked for that item, we store the
+    // error here. It is analogous to delivered_jobs but for errors instead of
+    // finished results.
+    Status final_error;
   };
 
   struct InnerCallback {
     void operator()(const Result<T>& maybe_next_ref) {
+      // An item has been delivered by one of the inner subscriptions
       Future<T> next_fut;
       const Result<T>* maybe_next = &maybe_next_ref;
+      // When an item is delivered (and the caller has asked for it) we grab the
+      // next item from the inner subscription. To avoid this behavior leading to an
+      // infinite loop (this can happen if the caller's callback asks for the next
+      // item) we use a while loop.
       while (true) {
         Future<T> sink;
         bool sub_finished = maybe_next->ok() && IsIterationEnd(**maybe_next);
+        bool pull_next_sub = false;
+        bool was_broken = false;
+        bool should_mark_gen_complete = false;
+        bool should_mark_final_error = false;
         {
           auto guard = state->mutex.Lock();
-          if (state->finished) {
-            // We've errored out so just ignore this result and don't keep pumping
-            return;
+          if (state->broken) {
+            // We've errored out previously so ignore the result. If anyone was waiting
+            // for this they will get IterationEnd when we purge
+            was_broken = true;
+          } else {
+            if (!sub_finished) {
+              // There is a result to deliver. Either we can deliver it now or we will
+              // queue it up
+              if (state->waiting_jobs.empty()) {
+                state->delivered_jobs.push_back(std::make_shared<DeliveredJob>(
+                    state->active_subscriptions[index], *maybe_next, index));
+              } else {
+                sink = std::move(*state->waiting_jobs.front());
+                state->waiting_jobs.pop_front();
+              }
+            }
+
+            // If this is the first error then we transition the state to a broken state
+            if (!maybe_next->ok()) {
+              should_mark_final_error = true;
+              state->SignalErrorUnlocked(guard);
+            }
           }
-          if (!sub_finished) {
-            if (state->waiting_jobs.empty()) {
-              state->delivered_jobs.push_back(std::make_shared<DeliveredJob>(
-                  state->active_subscriptions[index], *maybe_next, index));
-            } else {
-              sink = std::move(*state->waiting_jobs.front());
-              state->waiting_jobs.pop_front();
+
+          // If we finished this inner subscription then we need to grab a new inner
+          // subscription to take its spot. If we can't (because we're broken or
+          // exhausted) then we aren't going to be starting any new futures and so
+          // the number of running subscriptions drops.
+          pull_next_sub = sub_finished && !state->source_exhausted && !was_broken;
+          if (sub_finished && !pull_next_sub) {
+            state->num_running_subscriptions--;
+          }
+          // There are three situations in which we won't pull again: an error occurred,
+          // we are already finished, or no one was waiting for our result and so we
+          // queued it up. We will decrement outstanding_requests and possibly mark the
+          // generator completed.
+          if (state->broken || (!sink.is_valid() && !sub_finished) ||
+              (sub_finished && state->source_exhausted)) {
+            if (state->MarkTaskFinishedUnlocked(guard)) {
+              should_mark_gen_complete = true;
             }
           }
         }
-        if (sub_finished) {
+
+        // Now we have given up the lock and we can take all the actions we decided we
+        // need to take.
+        if (should_mark_final_error) {
+          state->MarkFinalError(maybe_next->status(), std::move(sink));
+        }
+
+        if (should_mark_gen_complete) {
+          state->MarkFinishedAndPurge();
+        }
+
+        // An error occurred elsewhere so there is no need to mark any future
+        // finished (will happen during the purge) or pull from anything
+        if (was_broken) {
+          return;
+        }
+
+        if (pull_next_sub) {
+          // We pulled an end token so we need to start a new subscription
+          // in our spot
           state->PullSource().AddCallback(OuterCallback{state, index});
         } else if (sink.is_valid()) {
+          // We pulled a valid result and there was someone waiting for it
+          // so let's fetch the next result from our subscription
           sink.MarkFinished(*maybe_next);
-          if (!maybe_next->ok()) return;
-
           next_fut = state->active_subscriptions[index]();
           if (next_fut.TryAddCallback([this]() { return *this; })) {
             return;
@@ -1111,6 +1367,8 @@
           maybe_next = &next_fut.result();
           continue;
         }
+        // else: We pulled a valid result but no one was waiting for it so
+        // we can just stop.
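+        // The inner subscription stays paused; operator() will restart it
+        // (via the deliverer member) when the caller claims the queued result.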
return; } } @@ -1120,43 +1378,45 @@ class MergedGenerator { struct OuterCallback { void operator()(const Result>& maybe_next) { - bool should_purge = false; + // We have been given a new inner subscription bool should_continue = false; + bool should_mark_gen_complete = false; + bool should_deliver_error = false; + bool source_exhausted = maybe_next.ok() && IsIterationEnd(*maybe_next); Future error_sink; { auto guard = state->mutex.Lock(); - if (!maybe_next.ok() || IsIterationEnd(*maybe_next)) { - state->source_exhausted = true; - if (!maybe_next.ok() || --state->num_active_subscriptions == 0) { - state->finished = true; - should_purge = true; - } - if (!maybe_next.ok()) { - if (state->waiting_jobs.empty()) { - state->delivered_jobs.push_back(std::make_shared( - AsyncGenerator(), maybe_next.status(), index)); - } else { + if (!maybe_next.ok() || source_exhausted || state->broken) { + // If here then we will not pull any more from the outer source + if (!state->broken && !maybe_next.ok()) { + state->SignalErrorUnlocked(guard); + // If here then we are the first error so we need to deliver it + should_deliver_error = true; + if (!state->waiting_jobs.empty()) { error_sink = std::move(*state->waiting_jobs.front()); state->waiting_jobs.pop_front(); } } + if (source_exhausted) { + state->source_exhausted = true; + state->num_running_subscriptions--; + } + if (state->MarkTaskFinishedUnlocked(guard)) { + should_mark_gen_complete = true; + } } else { state->active_subscriptions[index] = *maybe_next; should_continue = true; } } - if (error_sink.is_valid()) { - error_sink.MarkFinished(maybe_next.status()); + if (should_deliver_error) { + state->MarkFinalError(maybe_next.status(), std::move(error_sink)); + } + if (should_mark_gen_complete) { + state->MarkFinishedAndPurge(); } if (should_continue) { (*maybe_next)().AddCallback(InnerCallback{state, index}); - } else if (should_purge) { - // At this point state->finished has been marked true so no one else - // will be interacting with waiting_jobs and we can iterate outside lock - while (!state->waiting_jobs.empty()) { - state->waiting_jobs.front()->MarkFinished(IterationTraits::End()); - state->waiting_jobs.pop_front(); - } } } std::shared_ptr state; diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc index 7e5fccd9ef1..65af7e8eae9 100644 --- a/cpp/src/arrow/util/async_generator_test.cc +++ b/cpp/src/arrow/util/async_generator_test.cc @@ -551,11 +551,131 @@ TEST_P(MergedGeneratorTestFixture, Merged) { ASSERT_EQ(expected, concat_set); } +TEST_P(MergedGeneratorTestFixture, OuterSubscriptionEmpty) { + auto gen = AsyncVectorIt>({}); + if (IsSlow()) { + gen = SlowdownABit(gen); + } + auto merged_gen = MakeMergedGenerator(gen, 10); + ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, + CollectAsyncGenerator(std::move(merged_gen))); + ASSERT_TRUE(collected.empty()); +} + TEST_P(MergedGeneratorTestFixture, MergedInnerFail) { auto gen = AsyncVectorIt>( - {MakeSource({1, 2, 3}), MakeFailingSource()}); + {MakeSource({1, 2, 3}), FailsAt(MakeSource({1, 2, 3}), 1), MakeSource({1, 2, 3})}); auto merged_gen = MakeMergedGenerator(gen, 10); - ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(merged_gen)); + // Merged generator can be pulled async-reentrantly and we need to make + // sure, if it is, that all futures are marked complete, even if there is an error + std::vector> futures; + for (int i = 0; i < 20; i++) { + futures.push_back(merged_gen()); + } + // Items could come in any order so the only guarantee is that 
we see at least + // one item before the failure. After the failure the behavior is undefined + // except that we know the futures must complete. + bool error_seen = false; + for (int i = 0; i < 20; i++) { + Future fut = futures[i]; + ASSERT_TRUE(fut.Wait(arrow::kDefaultAssertFinishesWaitSeconds)); + Status status = futures[i].status(); + if (!status.ok()) { + ASSERT_GT(i, 0); + if (!error_seen) { + error_seen = true; + ASSERT_TRUE(status.IsInvalid()); + } + } + } +} + +TEST_P(MergedGeneratorTestFixture, MergedInnerFailCleanup) { + // The purpose of this test is to ensure we do not emit an error until all outstanding + // futures have completed. This is part of the AsyncGenerator contract + std::shared_ptr failing_task_gate = GatingTask::Make(); + std::shared_ptr passing_task_gate = GatingTask::Make(); + // A passing inner source emits one item and then waits on a gate and then + // emits a terminal item. + // + // A failing inner source emits one item and then waits on a gate and then + // emits an error. + auto make_source = [&](bool fails) -> AsyncGenerator { + std::shared_ptr> count = std::make_shared>(0); + return [&, fails, count]() -> Future { + int my_count = (*count)++; + if (my_count == 1) { + if (fails) { + return failing_task_gate->AsyncTask().Then( + []() -> Result { return Status::Invalid("XYZ"); }); + } else { + return passing_task_gate->AsyncTask().Then( + []() -> Result { return IterationEnd(); }); + } + } else { + return SleepABitAsync().Then([] { return TestInt(0); }); + } + }; + }; + auto outer = MakeVectorGenerator>( + {make_source(false), make_source(true), make_source(false)}); + auto merged_gen = MakeMergedGenerator(outer, 10); + + constexpr int NUM_FUTURES = 20; + std::vector> futures; + for (int i = 0; i < NUM_FUTURES; i++) { + futures.push_back(merged_gen()); + } + + auto count_completed_futures = [&] { + int count = 0; + for (const auto& future : futures) { + if (future.is_finished()) { + count++; + } + } + return count; + }; + + // The first future from each source can be emitted. The second from + // each source should be blocked by the gates. 
+  ASSERT_OK(passing_task_gate->WaitForRunning(2));
+  ASSERT_OK(failing_task_gate->WaitForRunning(1));
+  ASSERT_EQ(count_completed_futures(), 3);
+  // We will unlock the error now but it should not be emitted because
+  // the other futures are blocked
+  ASSERT_OK(failing_task_gate->Unlock());
+  SleepABit();
+  ASSERT_EQ(count_completed_futures(), 3);
+  // Now we will unlock the in-progress futures and everything should complete.
+  // We don't know exactly what order things will emit in but after the failure
+  // we should only see terminal items
+  ASSERT_OK(passing_task_gate->Unlock());
+
+  bool error_seen = false;
+  for (const auto& fut : futures) {
+    ASSERT_TRUE(fut.Wait(arrow::kDefaultAssertFinishesWaitSeconds));
+    if (fut.status().ok()) {
+      if (error_seen) {
+        ASSERT_TRUE(IsIterationEnd(*fut.result()));
+      }
+    } else {
+      // We should only see one error
+      ASSERT_FALSE(error_seen);
+      error_seen = true;
+      ASSERT_TRUE(fut.status().IsInvalid());
+    }
+  }
+}
+
+TEST_P(MergedGeneratorTestFixture, FinishesQuickly) {
+  // Testing a source that finishes on the first pull
+  auto source = AsyncVectorIt<AsyncGenerator<TestInt>>({MakeSource({1})});
+  auto merged = MakeMergedGenerator(std::move(source), 10);
+  ASSERT_FINISHES_OK_AND_EQ(TestInt(1), merged());
+  AssertGeneratorExhausted(merged);
+}
 
 TEST_P(MergedGeneratorTestFixture, MergedOuterFail) {
@@ -1310,6 +1430,23 @@ TEST(TestAsyncUtil, Readahead) {
   ASSERT_TRUE(IsIterationEnd(last_val));
 }
 
+TEST(TestAsyncUtil, ReadaheadOneItem) {
+  bool delivered = false;
+  auto source = [&delivered]() {
+    if (!delivered) {
+      delivered = true;
+      return Future<TestInt>::MakeFinished(0);
+    } else {
+      return Future<TestInt>::MakeFinished(IterationTraits<TestInt>::End());
+    }
+  };
+  auto readahead = MakeReadaheadGenerator<TestInt>(source, 10);
+  auto collected = CollectAsyncGenerator(std::move(readahead));
+  ASSERT_FINISHES_OK_AND_ASSIGN(auto actual, collected);
+  ASSERT_EQ(1, actual.size());
+  ASSERT_EQ(TestInt(0), actual[0]);
+}
+
 TEST(TestAsyncUtil, ReadaheadCopy) {
   auto source = AsyncVectorIt<TestInt>(RangeVector(6));
   auto gen = MakeReadaheadGenerator(std::move(source), 2);
@@ -1376,6 +1513,60 @@ TEST(TestAsyncUtil, ReadaheadFailed) {
   }
 }
 
+TEST(TestAsyncUtil, ReadaheadFailedWaitForInFlight) {
+  ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(20));
+  // If a failure causes an early end then we should not emit that failure
+  // until all in-flight futures have completed. This is to prevent tasks from
+  // outliving the generator.
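+  // One pooled task fails behind failure_gating_task while the other ten block
+  // behind in_flight_gating_task; the error must not surface until those ten
+  // in-flight tasks have finished.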
This is to prevent tasks from + // outliving the generator + std::atomic counter(0); + auto failure_gating_task = GatingTask::Make(); + auto in_flight_gating_task = GatingTask::Make(); + auto source = [&]() -> Future { + auto count = counter++; + return DeferNotOk(thread_pool->Submit([&, count]() -> Result { + if (count == 0) { + failure_gating_task->Task()(); + return Status::Invalid("X"); + } + in_flight_gating_task->Task()(); + // These are our in-flight tasks + return TestInt(0); + })); + }; + auto readahead = MakeReadaheadGenerator(source, 10); + auto should_be_invalid = readahead(); + ASSERT_OK(in_flight_gating_task->WaitForRunning(10)); + ASSERT_OK(failure_gating_task->Unlock()); + SleepABit(); + // Can't be finished because in-flight tasks are still running + AssertNotFinished(should_be_invalid); + ASSERT_OK(in_flight_gating_task->Unlock()); + ASSERT_FINISHES_AND_RAISES(Invalid, should_be_invalid); +} + +TEST(TestAsyncUtil, ReadaheadFailedStress) { + constexpr int NTASKS = 10; + ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(20)); + for (int i = 0; i < NTASKS; i++) { + std::atomic counter(0); + std::atomic finished(false); + AsyncGenerator source = [&]() -> Future { + auto count = counter++; + return DeferNotOk(thread_pool->Submit([&, count]() -> Result { + SleepABit(); + if (count == 5) { + return Status::Invalid("X"); + } + // Generator should not have been finished at this point + EXPECT_FALSE(finished); + return TestInt(0); + })); + }; + ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(source)); + finished.store(false); + } +} + class EnumeratorTestFixture : public GeneratorTestFixture { protected: void AssertEnumeratedCorrectly(AsyncGenerator>& gen,