diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h
index f034cea9983..0f3e9205f0c 100644
--- a/cpp/src/arrow/util/async_generator.h
+++ b/cpp/src/arrow/util/async_generator.h
@@ -1139,9 +1139,8 @@ class BackgroundGenerator {
  public:
  explicit BackgroundGenerator(Iterator<T> it, internal::Executor* io_executor, int max_q,
                               int q_restart)
-      : state_(std::make_shared<State>(io_executor, std::move(it), max_q, q_restart)) {}
-
-  ~BackgroundGenerator() {}
+      : state_(std::make_shared<State>(io_executor, std::move(it), max_q, q_restart)),
+        cleanup_(std::make_shared<Cleanup>(state_.get())) {}

   Future<T> operator()() {
     auto guard = state_->mutex.Lock();
@@ -1156,16 +1155,14 @@ class BackgroundGenerator {
     } else {
       auto next = Future<T>::MakeFinished(std::move(state_->queue.front()));
       state_->queue.pop();
-      if (!state_->running &&
-          static_cast<int>(state_->queue.size()) <= state_->q_restart) {
-        state_->RestartTask(state_, std::move(guard));
+      if (state_->NeedsRestart()) {
+        return state_->RestartTask(state_, std::move(guard), std::move(next));
       }
       return next;
     }
-    if (!state_->running) {
-      // This branch should only be needed to start the background thread on the first
-      // call
-      state_->RestartTask(state_, std::move(guard));
+    // This should only trigger the very first time this method is called
+    if (state_->NeedsRestart()) {
+      return state_->RestartTask(state_, std::move(guard), std::move(waiting_future));
     }
     return waiting_future;
   }
@@ -1174,11 +1171,12 @@ class BackgroundGenerator {
   struct State {
     State(internal::Executor* io_executor, Iterator<T> it, int max_q, int q_restart)
         : io_executor(io_executor),
+          max_q(max_q),
+          q_restart(q_restart),
           it(std::move(it)),
-          running(false),
+          reading(false),
           finished(false),
-          max_q(max_q),
-          q_restart(q_restart) {}
+          should_shutdown(false) {}

     void ClearQueue() {
       while (!queue.empty()) {
@@ -1186,78 +1184,165 @@
       }
     }

-    void RestartTask(std::shared_ptr<State> state, util::Mutex::Guard guard) {
-      if (!finished) {
-        running = true;
-        auto spawn_status = io_executor->Spawn([state]() { Task()(std::move(state)); });
-        if (!spawn_status.ok()) {
-          running = false;
-          finished = true;
-          if (waiting_future.has_value()) {
-            auto to_deliver = std::move(waiting_future.value());
-            waiting_future.reset();
-            guard.Unlock();
-            to_deliver.MarkFinished(spawn_status);
-          } else {
-            ClearQueue();
-            queue.push(spawn_status);
-          }
-        }
-      }
-    }
+    bool TaskIsRunning() const { return task_finished.is_valid(); }
+
+    bool NeedsRestart() const {
+      return !finished && !reading && static_cast<int>(queue.size()) <= q_restart;
+    }
+
+    void DoRestartTask(std::shared_ptr<State> state, util::Mutex::Guard guard) {
+      // If we get here we are actually going to start a new task so let's create a
+      // task_finished future for it
+      state->task_finished = Future<>::Make();
+      state->reading = true;
+      auto spawn_status = io_executor->Spawn(
+          [state]() { BackgroundGenerator::WorkerTask(std::move(state)); });
+      if (!spawn_status.ok()) {
+        // If we can't spawn a new task then send an error to the consumer (either via a
+        // waiting future or the queue) and mark ourselves finished
+        state->finished = true;
+        state->task_finished = Future<>();
+        if (waiting_future.has_value()) {
+          auto to_deliver = std::move(waiting_future.value());
+          waiting_future.reset();
+          guard.Unlock();
+          to_deliver.MarkFinished(spawn_status);
+        } else {
+          ClearQueue();
+          queue.push(spawn_status);
+        }
+      }
+    }
+
+    Future<T> RestartTask(std::shared_ptr<State> state, util::Mutex::Guard guard,
+                          Future<T> next) {
+      if (TaskIsRunning()) {
+        // If the task is still cleaning up we need to wait for it to finish before
+        // restarting.  We also want to block the consumer until we've restarted the
+        // reader to avoid multiple restarts
+        return task_finished.Then([state, next](...) {
+          // This may appear dangerous (recursive mutex) but we should be guaranteed the
+          // outer guard has been released by this point.  We know...
+          // * task_finished is not already finished (it would be invalid in that case)
+          // * task_finished will not be marked complete until we've given up the mutex
+          auto guard_ = state->mutex.Lock();
+          state->DoRestartTask(state, std::move(guard_));
+          return next;
+        });
+      }
+      // Otherwise we can restart immediately
+      DoRestartTask(std::move(state), std::move(guard));
+      return next;
+    }
+
     internal::Executor* io_executor;
+    const int max_q;
+    const int q_restart;
     Iterator<T> it;
-    bool running;
+
+    // If true, the task is actively pumping items from the iterator into the queue and
+    // does not need a restart
+    bool reading;
+    // Set to true when a terminal item arrives
     bool finished;
-    int max_q;
-    int q_restart;
+    // Signal to the background task to end early because consumers have given up on it
+    bool should_shutdown;
+    // If the queue is empty then the consumer will create a waiting future and wait for
+    // it
     std::queue<Result<T>> queue;
     util::optional<Future<T>> waiting_future;
+    // Every background task is given a future to complete when it is entirely finished
+    // processing and ready for the next task to start or for State to be destroyed
+    Future<> task_finished;
     util::Mutex mutex;
   };

-  class Task {
-   public:
-    void operator()(std::shared_ptr<State> state) {
-      // while condition can't be based on state_ because it is run outside the mutex
-      bool running = true;
-      while (running) {
-        auto next = state->it.Next();
-        // Need to capture state->waiting_future inside the mutex to mark finished outside
-        Future<T> waiting_future;
-        {
-          auto guard = state->mutex.Lock();
-
-          if (!next.ok() || IsIterationEnd(*next)) {
-            state->finished = true;
-            state->running = false;
-            if (!next.ok()) {
-              state->ClearQueue();
-            }
-          }
-          if (state->waiting_future.has_value()) {
-            waiting_future = std::move(state->waiting_future.value());
-            state->waiting_future.reset();
-          } else {
-            state->queue.push(std::move(next));
-            if (static_cast<int>(state->queue.size()) >= state->max_q) {
-              state->running = false;
-            }
-          }
-          running = state->running;
-        }
-        // This must happen outside the task.  Although presumably there is a transferring
-        // generator on the other end that will quickly transfer any callbacks off of this
-        // thread so we can continue looping.  Still, best not to rely on that
-        if (waiting_future.is_valid()) {
-          waiting_future.MarkFinished(next);
-        }
-      }
-    }
-  };
+  // Cleanup task that will be run when all consumer references to the generator are lost
+  struct Cleanup {
+    explicit Cleanup(State* state) : state(state) {}
+    ~Cleanup() {
+      Future<> finish_fut;
+      {
+        auto lock = state->mutex.Lock();
+        if (!state->TaskIsRunning()) {
+          return;
+        }
+        // Signal the current task to stop and wait for it to finish
+        state->should_shutdown = true;
+        finish_fut = state->task_finished;
+      }
+      // Using the future as a condition variable here
+      Status st = finish_fut.status();
+      ARROW_UNUSED(st);
+    }
+    State* state;
+  };
+
+  static void WorkerTask(std::shared_ptr<State> state) {
+    // Capture the reading flag locally because the while condition runs outside the mutex
+    bool reading = true;
+    while (reading) {
+      auto next = state->it.Next();
+      // Need to capture state->waiting_future inside the mutex to mark finished outside
+      Future<T> waiting_future;
+      {
+        auto guard = state->mutex.Lock();
+
+        if (state->should_shutdown) {
+          state->finished = true;
+          break;
+        }
+
+        if (!next.ok() || IsIterationEnd(*next)) {
+          // Terminal item.  Mark finished true, send this last item, and quit
+          state->finished = true;
+          if (!next.ok()) {
+            state->ClearQueue();
+          }
+        }
+        // At this point we are going to send an item.  Either we will add it to the
+        // queue or deliver it to a waiting future.
+        if (state->waiting_future.has_value()) {
+          waiting_future = std::move(state->waiting_future.value());
+          state->waiting_future.reset();
+        } else {
+          state->queue.push(std::move(next));
+          // We just filled up the queue so it is time to quit.  We may need to notify
+          // a waiting cleanup task when this task finishes
+          if (static_cast<int>(state->queue.size()) >= state->max_q) {
+            state->reading = false;
+          }
+        }
+        reading = state->reading && !state->finished;
+      }
+      // This should happen outside the mutex.  Presumably there is a
+      // transferring generator on the other end that will quickly transfer any
+      // callbacks off of this thread so we can continue looping.  Still, best not to
+      // rely on that
+      if (waiting_future.is_valid()) {
+        waiting_future.MarkFinished(next);
+      }
+    }
+    // Once we've sent our last item we can notify any waiters that we are done so that
+    // either state can be cleaned up or a new background task can be started
+    Future<> task_finished;
+    {
+      auto guard = state->mutex.Lock();
+      // After we give up the mutex, state can be safely deleted.  We will no longer
+      // reference it.  We can safely transition to idle now.
+      task_finished = state->task_finished;
+      state->task_finished = Future<>();
+    }
+    task_finished.MarkFinished();
+  }

   std::shared_ptr<State> state_;
+  // state_ is held by both the generator and the background thread so it won't be cleaned
+  // up when all consumer references are relinquished.  cleanup_ is only held by the
+  // generator so it will be destructed when the last consumer reference is gone.  We use
+  // this to clean up / stop the background generator in case the consuming end stops
+  // listening (e.g. due to a downstream error)
+  std::shared_ptr<Cleanup> cleanup_;
 };

 constexpr int kDefaultBackgroundMaxQ = 32;
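Reviewer note, not part of the patch: a minimal consumer-side sketch of the lifetime contract the new cleanup_ member creates, assuming only the APIs shown above (MakeBackgroundGenerator, AsyncGenerator, internal::GetCpuThreadPool). The helper name AbandonAfterOneItem is illustrative. Because cleanup_ is held only by the generator, dropping the last AsyncGenerator copy runs ~Cleanup, which sets should_shutdown and blocks until the worker marks task_finished.

#include <utility>

#include "arrow/util/async_generator.h"
#include "arrow/util/thread_pool.h"

template <typename T>
void AbandonAfterOneItem(arrow::Iterator<T> it) {
  auto maybe_gen =
      arrow::MakeBackgroundGenerator(std::move(it), arrow::internal::GetCpuThreadPool());
  if (!maybe_gen.ok()) return;
  {
    arrow::AsyncGenerator<T> generator = maybe_gen.MoveValueUnsafe();
    // The first call spawns WorkerTask on the io executor and waits for one item
    auto first = generator();
    first.Wait();
    // `generator` (and with it the only reference to Cleanup) is destroyed here...
  }
  // ...so by this point ~Cleanup has run and the background task is guaranteed to have
  // stopped, either because it had already gone idle or because it observed
  // should_shutdown, even though the iterator was never drained.
}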
diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc
index 51e4f948d38..dfb60ce7a7d 100644
--- a/cpp/src/arrow/util/async_generator_test.cc
+++ b/cpp/src/arrow/util/async_generator_test.cc
@@ -814,6 +814,65 @@ TEST_P(BackgroundGeneratorTestFixture, StopAndRestart) {
   AssertGeneratorExhausted(generator);
 }

+struct TrackingIterator {
+  explicit TrackingIterator(bool slow)
+      : token(std::make_shared<bool>(false)), slow(slow) {}
+
+  Result<TestInt> Next() {
+    if (slow) {
+      SleepABit();
+    }
+    return TestInt(0);
+  }
+
+  std::weak_ptr<bool> GetWeakTargetRef() { return std::weak_ptr<bool>(token); }
+
+  std::shared_ptr<bool> token;
+  bool slow;
+};
+
+TEST_P(BackgroundGeneratorTestFixture, AbortReading) {
+  // If there is an error downstream then it is likely the chain will abort and the
+  // background generator will lose all references and should abandon reading
+  TrackingIterator source(IsSlow());
+  auto tracker = source.GetWeakTargetRef();
+  auto iter = Iterator<TestInt>(std::move(source));
+  std::shared_ptr<AsyncGenerator<TestInt>> generator;
+  {
+    ASSERT_OK_AND_ASSIGN(
+        auto gen, MakeBackgroundGenerator(std::move(iter), internal::GetCpuThreadPool()));
+    generator = std::make_shared<AsyncGenerator<TestInt>>(gen);
+  }
+
+  // Poll one item to start it up
+  ASSERT_FINISHES_OK_AND_EQ(TestInt(0), (*generator)());
+  ASSERT_FALSE(tracker.expired());
+  // Remove the last reference to the generator; this should trigger and wait for cleanup
+  generator.reset();
+  // Cleanup should have ensured there are no more references to the source.  It may take
+  // a moment to expire because the background thread has to destruct itself
+  BusyWait(10, [&tracker] { return tracker.expired(); });
+}
+
+TEST_P(BackgroundGeneratorTestFixture, AbortOnIdleBackground) {
+  // Tests what happens when the downstream aborts while the background thread is idle
+  ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+
+  auto source = PossiblySlowVectorIt(RangeVector(100), IsSlow());
+  std::shared_ptr<AsyncGenerator<TestInt>> generator;
+  {
+    ASSERT_OK_AND_ASSIGN(auto gen,
+                         MakeBackgroundGenerator(std::move(source), thread_pool.get()));
+    generator = std::make_shared<AsyncGenerator<TestInt>>(gen);
+  }
+  ASSERT_FINISHES_OK_AND_EQ(TestInt(0), (*generator)());
+
+  // The generator should pretty quickly fill up the queue and go idle
+  BusyWait(10, [&thread_pool] { return thread_pool->GetNumTasks() == 0; });
+
+  // Now delete the generator and hope we don't deadlock
+  generator.reset();
+}
+
 struct SlowEmptyIterator {
   Result<TestInt> Next() {
     if (called_) {
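Reviewer note, not part of the patch: the AbortReading test proves the worker released its source via a weak_ptr "token". A reduced, Arrow-free sketch of that pattern, with illustrative names Source and WeakRef:

#include <cassert>
#include <memory>
#include <utility>

struct Source {
  std::shared_ptr<int> token = std::make_shared<int>(0);
  std::weak_ptr<int> WeakRef() const { return token; }
};

int main() {
  Source source;
  auto tracker = source.WeakRef();
  {
    // Hand the source to some owner (here a moved-to local; in the test it is the
    // background worker's State)
    Source owner = std::move(source);
    assert(!tracker.expired());  // the owner still holds the token
  }
  // Once every owner is gone the weak_ptr expires, proving the source was released
  assert(tracker.expired());
  return 0;
}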
diff --git a/cpp/src/arrow/util/thread_pool.cc b/cpp/src/arrow/util/thread_pool.cc
index 873b9335e74..cd523609d27 100644
--- a/cpp/src/arrow/util/thread_pool.cc
+++ b/cpp/src/arrow/util/thread_pool.cc
@@ -272,6 +272,12 @@ int ThreadPool::GetCapacity() {
   return state_->desired_capacity_;
 }

+int ThreadPool::GetNumTasks() {
+  ProtectAgainstFork();
+  std::unique_lock<std::mutex> lock(state_->mutex_);
+  return state_->tasks_queued_or_running_;
+}
+
 int ThreadPool::GetActualCapacity() {
   ProtectAgainstFork();
   std::unique_lock<std::mutex> lock(state_->mutex_);
diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h
index c4d4d1869c6..cd964385c6e 100644
--- a/cpp/src/arrow/util/thread_pool.h
+++ b/cpp/src/arrow/util/thread_pool.h
@@ -264,6 +264,9 @@ class ARROW_EXPORT ThreadPool : public Executor {
   // match this value.
   int GetCapacity() override;

+  // Return the number of tasks either running or in the queue.
+  int GetNumTasks();
+
   // Dynamically change the number of worker threads.
   //
   // This function always returns immediately.
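Reviewer note, not part of the patch: a sketch of how the new ThreadPool::GetNumTasks accessor can be used outside the tests, for example to wait until a pool drains. WaitForIdle and its timeout handling are illustrative, not an API proposed here.

#include <chrono>
#include <thread>

#include "arrow/status.h"
#include "arrow/util/thread_pool.h"

arrow::Status WaitForIdle(arrow::internal::ThreadPool* pool,
                          std::chrono::milliseconds timeout) {
  auto deadline = std::chrono::steady_clock::now() + timeout;
  // GetNumTasks counts tasks that are queued or currently running, so zero means idle
  while (pool->GetNumTasks() > 0) {
    if (std::chrono::steady_clock::now() > deadline) {
      return arrow::Status::Invalid("Thread pool did not go idle in time");
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
  return arrow::Status::OK();
}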