215 changes: 150 additions & 65 deletions cpp/src/arrow/util/async_generator.h
@@ -1139,9 +1139,8 @@ class BackgroundGenerator {
public:
explicit BackgroundGenerator(Iterator<T> it, internal::Executor* io_executor, int max_q,
int q_restart)
: state_(std::make_shared<State>(io_executor, std::move(it), max_q, q_restart)) {}

~BackgroundGenerator() {}
: state_(std::make_shared<State>(io_executor, std::move(it), max_q, q_restart)),
cleanup_(std::make_shared<Cleanup>(state_.get())) {}

Future<T> operator()() {
auto guard = state_->mutex.Lock();
@@ -1156,16 +1155,14 @@ class BackgroundGenerator {
} else {
auto next = Future<T>::MakeFinished(std::move(state_->queue.front()));
state_->queue.pop();
if (!state_->running &&
static_cast<int>(state_->queue.size()) <= state_->q_restart) {
state_->RestartTask(state_, std::move(guard));
if (state_->NeedsRestart()) {
return state_->RestartTask(state_, std::move(guard), std::move(next));
}
return next;
}
if (!state_->running) {
// This branch should only be needed to start the background thread on the first
// call
state_->RestartTask(state_, std::move(guard));
// This should only trigger the very first time this method is called
if (state_->NeedsRestart()) {
return state_->RestartTask(state_, std::move(guard), std::move(waiting_future));
}
return waiting_future;
}
@@ -1174,90 +1171,178 @@
struct State {
State(internal::Executor* io_executor, Iterator<T> it, int max_q, int q_restart)
: io_executor(io_executor),
max_q(max_q),
q_restart(q_restart),
it(std::move(it)),
running(false),
reading(false),
finished(false),
max_q(max_q),
q_restart(q_restart) {}
should_shutdown(false) {}

void ClearQueue() {
while (!queue.empty()) {
queue.pop();
}
}

void RestartTask(std::shared_ptr<State> state, util::Mutex::Guard guard) {
if (!finished) {
running = true;
auto spawn_status = io_executor->Spawn([state]() { Task()(std::move(state)); });
if (!spawn_status.ok()) {
running = false;
finished = true;
if (waiting_future.has_value()) {
auto to_deliver = std::move(waiting_future.value());
waiting_future.reset();
guard.Unlock();
to_deliver.MarkFinished(spawn_status);
} else {
ClearQueue();
queue.push(spawn_status);
}
bool TaskIsRunning() const { return task_finished.is_valid(); }

bool NeedsRestart() const {
return !finished && !reading && static_cast<int>(queue.size()) <= q_restart;
}

void DoRestartTask(std::shared_ptr<State> state, util::Mutex::Guard guard) {
// If we get here we are actually going to start a new task so let's create a
// task_finished future for it
state->task_finished = Future<>::Make();
state->reading = true;
auto spawn_status = io_executor->Spawn(
[state]() { BackgroundGenerator::WorkerTask(std::move(state)); });
if (!spawn_status.ok()) {
// If we can't spawn a new task then send an error to the consumer (either via a
// waiting future or the queue) and mark ourselves finished
state->finished = true;
state->task_finished = Future<>();
if (waiting_future.has_value()) {
auto to_deliver = std::move(waiting_future.value());
waiting_future.reset();
guard.Unlock();
to_deliver.MarkFinished(spawn_status);
} else {
ClearQueue();
queue.push(spawn_status);
}
}
}

Future<T> RestartTask(std::shared_ptr<State> state, util::Mutex::Guard guard,
Future<T> next) {
if (TaskIsRunning()) {
// If the task is still cleaning up we need to wait for it to finish before
// restarting. We also want to block the consumer until we've restarted the
// reader to avoid multiple restarts
return task_finished.Then([state, next](...) {
// This may appear dangerous (recursive mutex) but we should be guaranteed the
// outer guard has been released by this point. We know...
// * task_finished is not already finished (it would be invalid in that case)
// * task_finished will not be marked complete until we've given up the mutex
auto guard_ = state->mutex.Lock();
state->DoRestartTask(state, std::move(guard_));
return next;
});
}
// Otherwise we can restart immediately
DoRestartTask(std::move(state), std::move(guard));
return next;
}

internal::Executor* io_executor;
const int max_q;
const int q_restart;
Iterator<T> it;
bool running;

// If true, the task is actively pumping items from the iterator into the queue and
// does not need a restart
bool reading;
// Set to true when a terminal item arrives
bool finished;
Member commented:

This is where complexity starts being difficult to reason about. What's the difference between running_at_all and finished? What are the possible combinations?
You may want to define an enum to describe the current state rather than having three different booleans...

Member (author) replied:

finished means that the background thread is done. Any future reads that would require starting the background generator back up should return an end token.

running_at_all means not only that it is finished but also that the background thread is done delivering the last item (we mark finished and "reserve a spot" in the mutex but can't deliver the item while holding the mutex) and is not going to check whether it needs to mark the final future complete.

I will consider an enum and see if the 8 possible states collapse down to some reasonable subset.
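
For illustration only, a sketch of what such an enum might look like if the flag combinations collapse as hoped (names are hypothetical, not code from this PR):

// Hypothetical sketch: collapse the reading/finished/should_shutdown flags into a
// single explicit state machine.
enum class BackgroundTaskState {
  kIdle,      // no background task running; the queue is above the restart threshold
  kReading,   // the background task is actively pulling items from the iterator
  kQuitting,  // the consumer asked the task to stop; waiting on task_finished
  kFinished   // a terminal item (or error) was seen; the task will not be restarted
};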

int max_q;
int q_restart;
// Signal to the background task to end early because consumers have given up on it
bool should_shutdown;
// If the queue is empty then the consumer will create a waiting future and wait for
// it
std::queue<Result<T>> queue;
util::optional<Future<T>> waiting_future;
// Every background task is given a future to complete when it is entirely finished
// processing and ready for the next task to start or for State to be destroyed
Future<> task_finished;
util::Mutex mutex;
};

class Task {
public:
void operator()(std::shared_ptr<State> state) {
// while condition can't be based on state_ because it is run outside the mutex
bool running = true;
while (running) {
auto next = state->it.Next();
// Need to capture state->waiting_future inside the mutex to mark finished outside
Future<T> waiting_future;
{
auto guard = state->mutex.Lock();
// Cleanup task that will be run when all consumer references to the generator are lost
struct Cleanup {
explicit Cleanup(State* state) : state(state) {}
~Cleanup() {
Future<> finish_fut;
{
auto lock = state->mutex.Lock();
if (!state->TaskIsRunning()) {
return;
}
// Signal the current task to stop and wait for it to finish
state->should_shutdown = true;
finish_fut = state->task_finished;
}
// Using future as a condition variable here
Status st = finish_fut.status();
ARROW_UNUSED(st);
}
State* state;
};

if (!next.ok() || IsIterationEnd<T>(*next)) {
state->finished = true;
state->running = false;
if (!next.ok()) {
state->ClearQueue();
}
}
if (state->waiting_future.has_value()) {
waiting_future = std::move(state->waiting_future.value());
state->waiting_future.reset();
} else {
state->queue.push(std::move(next));
if (static_cast<int>(state->queue.size()) >= state->max_q) {
state->running = false;
}
static void WorkerTask(std::shared_ptr<State> state) {
// Capture whether we should keep reading into a local so the loop condition can be
// checked outside the mutex
bool reading = true;
while (reading) {
auto next = state->it.Next();
// Need to capture state->waiting_future inside the mutex to mark finished outside
Future<T> waiting_future;
{
auto guard = state->mutex.Lock();

if (state->should_shutdown) {
state->finished = true;
break;
}

if (!next.ok() || IsIterationEnd<T>(*next)) {
// Terminal item. Set finished to true, send this last item, and quit
state->finished = true;
if (!next.ok()) {
state->ClearQueue();
}
running = state->running;
}
// This must happen outside the task. Although presumably there is a transferring
// generator on the other end that will quickly transfer any callbacks off of this
// thread so we can continue looping. Still, best not to rely on that
if (waiting_future.is_valid()) {
waiting_future.MarkFinished(next);
// At this point we are going to send an item. Either we will add it to the
// queue or deliver it to a waiting future.
if (state->waiting_future.has_value()) {
waiting_future = std::move(state->waiting_future.value());
state->waiting_future.reset();
} else {
state->queue.push(std::move(next));
// We just filled up the queue so it is time to quit. We may need to notify
// a cleanup task so we transition to Quitting
if (static_cast<int>(state->queue.size()) >= state->max_q) {
state->reading = false;
}
}
reading = state->reading && !state->finished;
}
// This should happen outside the mutex. Presumably there is a
// transferring generator on the other end that will quickly transfer any
// callbacks off of this thread so we can continue looping. Still, best not to
// rely on that
if (waiting_future.is_valid()) {
waiting_future.MarkFinished(next);
}
}
};
// Once we've sent our last item we can notify any waiters that we are done and so
// either state can be cleaned up or a new background task can be started
Future<> task_finished;
{
auto guard = state->mutex.Lock();
// After we give up the mutex state can be safely deleted. We will no longer
// reference it. We can safely transition to idle now.
task_finished = state->task_finished;
state->task_finished = Future<>();
}
task_finished.MarkFinished();
}

std::shared_ptr<State> state_;
// state_ is held by both the generator and the background thread so it won't be cleaned
// up when all consumer references are relinquished. cleanup_ is only held by the
// generator so it will be destructed when the last consumer reference is gone. We use
// this to cleanup / stop the background generator in case the consuming end stops
// listening (e.g. due to a downstream error)
std::shared_ptr<Cleanup> cleanup_;
};

constexpr int kDefaultBackgroundMaxQ = 32;
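As an aside, the state_/cleanup_ ownership split above can be shown in isolation. The following is a minimal standalone sketch, not the PR's actual API: it uses a plain std::thread instead of an executor and hypothetical names, but demonstrates the same idea that the worker shares ownership of the state while only consumer-facing handles own the cleanup sentinel, so dropping the last handle signals the worker and waits for it.

#include <atomic>
#include <chrono>
#include <memory>
#include <thread>

// Shared by both the handle and the worker thread.
struct WorkerState {
  std::atomic<bool> should_shutdown{false};
};

// Held only by the handle; its destructor runs when the last handle copy is dropped.
struct WorkerCleanup {
  std::shared_ptr<WorkerState> state;
  std::thread worker;
  ~WorkerCleanup() {
    if (state) {
      state->should_shutdown = true;  // signal the worker to stop
    }
    if (worker.joinable()) {
      worker.join();  // wait for it so the state is no longer referenced
    }
  }
};

class Handle {
 public:
  Handle()
      : state_(std::make_shared<WorkerState>()),
        cleanup_(std::make_shared<WorkerCleanup>()) {
    cleanup_->state = state_;
    cleanup_->worker = std::thread([state = state_]() {
      while (!state->should_shutdown.load()) {
        // ... produce items ...
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
      }
    });
  }

 private:
  // state_ is also captured by the worker, so it stays alive while the worker runs.
  std::shared_ptr<WorkerState> state_;
  // cleanup_ is held only by Handle copies, so losing the last one stops the worker.
  std::shared_ptr<WorkerCleanup> cleanup_;
};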
59 changes: 59 additions & 0 deletions cpp/src/arrow/util/async_generator_test.cc
@@ -814,6 +814,65 @@ TEST_P(BackgroundGeneratorTestFixture, StopAndRestart) {
AssertGeneratorExhausted(generator);
}

struct TrackingIterator {
explicit TrackingIterator(bool slow)
: token(std::make_shared<bool>(false)), slow(slow) {}

Result<TestInt> Next() {
if (slow) {
SleepABit();
}
return TestInt(0);
}
std::weak_ptr<bool> GetWeakTargetRef() { return std::weak_ptr<bool>(token); }

std::shared_ptr<bool> token;
bool slow;
};

TEST_P(BackgroundGeneratorTestFixture, AbortReading) {
// If there is an error downstream then it is likely the chain will abort and the
// background generator will lose all references and should abandon reading
TrackingIterator source(IsSlow());
auto tracker = source.GetWeakTargetRef();
auto iter = Iterator<TestInt>(std::move(source));
std::shared_ptr<AsyncGenerator<TestInt>> generator;
{
ASSERT_OK_AND_ASSIGN(
auto gen, MakeBackgroundGenerator(std::move(iter), internal::GetCpuThreadPool()));
generator = std::make_shared<AsyncGenerator<TestInt>>(gen);
}

// Poll one item to start it up
ASSERT_FINISHES_OK_AND_EQ(TestInt(0), (*generator)());
ASSERT_FALSE(tracker.expired());
// Remove last reference to generator, should trigger and wait for cleanup
generator.reset();
// Cleanup should have ensured no more reference to the source. It may take a moment
// to expire because the background thread has to destruct itself
BusyWait(10, [&tracker] { return tracker.expired(); });
}

TEST_P(BackgroundGeneratorTestFixture, AbortOnIdleBackground) {
// Tests what happens when the downstream aborts while the background thread is idle
ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));

auto source = PossiblySlowVectorIt(RangeVector(100), IsSlow());
std::shared_ptr<AsyncGenerator<TestInt>> generator;
{
ASSERT_OK_AND_ASSIGN(auto gen,
MakeBackgroundGenerator(std::move(source), thread_pool.get()));
generator = std::make_shared<AsyncGenerator<TestInt>>(gen);
}
ASSERT_FINISHES_OK_AND_EQ(TestInt(0), (*generator)());

// The generator should pretty quickly fill up the queue and idle
BusyWait(10, [&thread_pool] { return thread_pool->GetNumTasks() == 0; });

// Now delete the generator and hope we don't deadlock
generator.reset();
}

struct SlowEmptyIterator {
Result<TestInt> Next() {
if (called_) {
6 changes: 6 additions & 0 deletions cpp/src/arrow/util/thread_pool.cc
@@ -272,6 +272,12 @@ int ThreadPool::GetCapacity() {
return state_->desired_capacity_;
}

int ThreadPool::GetNumTasks() {
ProtectAgainstFork();
std::unique_lock<std::mutex> lock(state_->mutex_);
return state_->tasks_queued_or_running_;
}

int ThreadPool::GetActualCapacity() {
ProtectAgainstFork();
std::unique_lock<std::mutex> lock(state_->mutex_);
3 changes: 3 additions & 0 deletions cpp/src/arrow/util/thread_pool.h
@@ -264,6 +264,9 @@ class ARROW_EXPORT ThreadPool : public Executor {
// match this value.
int GetCapacity() override;

// Return the number of tasks either running or in the queue.
int GetNumTasks();

// Dynamically change the number of worker threads.
//
// This function always returns immediately.