diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index a815b092800..57c6d9b5dde 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -748,21 +748,31 @@ class ReadaheadGenerator { auto state = state_; return fut.Then( [state](const T& result) -> Future { - state->MarkFinishedIfDone(result); - if (state->finished.load()) { - if (state->num_running.fetch_sub(1) == 1) { - state->final_future.MarkFinished(); + bool mark_finished = false; + { + auto guard = state->mutex.Lock(); + state->MarkFinishedIfDone(result); + --state->num_running; + if (state->finished) { + mark_finished = state->num_running == 0; } - } else { - state->num_running.fetch_sub(1); + } + if (mark_finished) { + state->final_future.MarkFinished(); } return result; }, [state](const Status& err) -> Future { // If there is an error we need to make sure all running // tasks finish before we return the error. - state->finished.store(true); - if (state->num_running.fetch_sub(1) == 1) { + bool mark_finished = false; + { + auto guard = state->mutex.Lock(); + state->finished = true; + --state->num_running; + mark_finished = state->num_running == 0; + } + if (mark_finished) { state->final_future.MarkFinished(); } return state->final_future.Then([err]() -> Result { return err; }); @@ -772,7 +782,12 @@ class ReadaheadGenerator { Future operator()() { if (state_->readahead_queue.empty()) { // This is the first request, let's pump the underlying queue - state_->num_running.store(state_->max_readahead); + { + auto guard = state_->mutex.Lock(); + // We're going to push to the queue below, but we need + // to update `num_running` while we're holding the lock. + state_->num_running = state_->max_readahead; + } for (int i = 0; i < state_->max_readahead; i++) { auto next = state_->source_generator(); auto next_after_check = AddMarkFinishedContinuation(std::move(next)); @@ -780,12 +795,21 @@ class ReadaheadGenerator { } } // Pop one and add one - auto result = state_->readahead_queue.front(); + auto result = std::move(state_->readahead_queue.front()); state_->readahead_queue.pop(); - if (state_->finished.load()) { + bool is_finished = false; + { + auto guard = state_->mutex.Lock(); + is_finished = state_->finished; + if (!is_finished) { + // We're going to push to the queue below, but we need + // to update `num_running` while we're holding the lock. + ++state_->num_running; + } + } + if (is_finished) { state_->readahead_queue.push(AsyncGeneratorEnd()); } else { - state_->num_running.fetch_add(1); auto back_of_queue = state_->source_generator(); auto back_of_queue_after_check = AddMarkFinishedContinuation(std::move(back_of_queue)); @@ -800,16 +824,18 @@ class ReadaheadGenerator { : source_generator(std::move(source_generator)), max_readahead(max_readahead) {} void MarkFinishedIfDone(const T& next_result) { + // ASSERT_HELD(mutex) if (IsIterationEnd(next_result)) { - finished.store(true); + finished = true; } } AsyncGenerator source_generator; int max_readahead; Future<> final_future = Future<>::Make(); - std::atomic num_running{0}; - std::atomic finished{false}; + int num_running{0}; // GUARDED_BY(mutex) + bool finished{false}; // GUARDED_BY(mutex) + arrow::util::Mutex mutex; std::queue> readahead_queue; };