Skip to content
55 changes: 55 additions & 0 deletions cpp/src/arrow/testing/executor_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <utility>
#include <vector>

#include "arrow/util/thread_pool.h"

namespace arrow {

/// \brief An executor which synchronously runs the task as part of the SpawnReal call.
///
/// Test utility: `spawn_count` records how many tasks were spawned, so a test can tell
/// whether a piece of work went through the executor or was run directly.
class MockExecutor : public internal::Executor {
 public:
  /// Always reports a capacity of 0.
  int GetCapacity() override { return 0; }

  /// Count the spawn, then run `task` to completion before returning.
  ///
  /// The task hints, stop token and stop callback are ignored, which is why those
  /// parameters are left unnamed.
  /// \return always Status::OK()
  Status SpawnReal(internal::TaskHints, internal::FnOnce<void()> task, StopToken,
                   StopCallback&&) override {
    ++spawn_count;
    std::move(task)();
    return Status::OK();
  }

  /// Number of times SpawnReal has been called on this executor.
  int spawn_count = 0;
};

/// \brief An executor which does not actually run the task; it only stores it.
///
/// Can be used to simulate situations where the executor schedules a task in a long
/// queue and doesn't get around to running it for a while.  Spawned tasks accumulate in
/// `captured_tasks`, in spawn order, for the test to inspect or invoke later.
class DelayedExecutor : public internal::Executor {
 public:
  /// Always reports a capacity of 0.
  int GetCapacity() override { return 0; }

  /// Append `task` to `captured_tasks` without running it.
  ///
  /// The task hints, stop token and stop callback are ignored, which is why those
  /// parameters are left unnamed.
  /// \return always Status::OK()
  Status SpawnReal(internal::TaskHints, internal::FnOnce<void()> task, StopToken,
                   StopCallback&&) override {
    captured_tasks.push_back(std::move(task));
    return Status::OK();
  }

  /// Every task handed to SpawnReal so far, oldest first.
  std::vector<internal::FnOnce<void()>> captured_tasks;
};

} // namespace arrow
66 changes: 55 additions & 11 deletions cpp/src/arrow/util/future.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/util/thread_pool.h"

namespace arrow {

Expand Down Expand Up @@ -231,26 +232,68 @@ class ConcreteFutureImpl : public FutureImpl {

void DoMarkFailed() { DoMarkFinishedOrFailed(FutureState::FAILURE); }

void AddCallback(Callback callback) {
// Debug-only sanity check on callback options: any scheduling mode other than Never
// needs an executor to schedule on.
void CheckOptions(const CallbackOptions& opts) {
  if (opts.should_schedule == ShouldSchedule::Never) {
    // Purely synchronous callbacks need no executor.
    return;
  }
  DCHECK_NE(opts.executor, nullptr)
      << "An executor must be specified when adding a callback that might schedule";
}

void AddCallback(Callback callback, CallbackOptions opts) {
CheckOptions(opts);
std::unique_lock<std::mutex> lock(mutex_);
CallbackRecord callback_record{std::move(callback), opts};
if (IsFutureFinished(state_)) {
lock.unlock();
std::move(callback)();
RunOrScheduleCallback(std::move(callback_record), /*in_add_callback=*/true);
} else {
callbacks_.push_back(std::move(callback));
callbacks_.push_back(std::move(callback_record));
}
}

bool TryAddCallback(const std::function<Callback()>& callback_factory) {
bool TryAddCallback(const std::function<Callback()>& callback_factory,
CallbackOptions opts) {
CheckOptions(opts);
std::unique_lock<std::mutex> lock(mutex_);
if (IsFutureFinished(state_)) {
return false;
} else {
callbacks_.push_back(callback_factory());
callbacks_.push_back({callback_factory(), opts});
return true;
}
}

bool ShouldScheduleCallback(const CallbackRecord& callback_record,
bool in_add_callback) {
switch (callback_record.options.should_schedule) {
case ShouldSchedule::Never:
return false;
case ShouldSchedule::Always:
return true;
case ShouldSchedule::IfUnfinished:
return !in_add_callback;
default:
DCHECK(false) << "Unrecognized ShouldSchedule option";
return false;
}
}

// Either invoke the callback right away from the current call, or hand it to the
// executor named in its options, as decided by ShouldScheduleCallback().
void RunOrScheduleCallback(CallbackRecord&& callback_record, bool in_add_callback) {
  if (!ShouldScheduleCallback(callback_record, in_add_callback)) {
    // Synchronous path: run inline, passing this future to the callback.
    std::move(callback_record.callback)(*this);
    return;
  }

  // Task object given to the executor.  Besides the callback it carries a strong
  // reference to this future, so the future stays alive until the task has run.
  struct ScheduledCallback {
    void operator()() { std::move(callback)(*self); }

    Callback callback;
    std::shared_ptr<FutureImpl> self;
  };
  ScheduledCallback scheduled{std::move(callback_record.callback), shared_from_this()};
  DCHECK_OK(callback_record.options.executor->Spawn(std::move(scheduled)));
}

void DoMarkFinishedOrFailed(FutureState state) {
{
// Lock the hypothetical waiter first, and the future after.
Expand All @@ -272,8 +315,8 @@ class ConcreteFutureImpl : public FutureImpl {
//
// In fact, it is important not to hold the locks because the callback
// may be slow or do its own locking on other resources
for (auto&& callback : callbacks_) {
std::move(callback)();
for (auto& callback_record : callbacks_) {
RunOrScheduleCallback(std::move(callback_record), /*in_add_callback=*/false);
}
callbacks_.clear();
}
Expand Down Expand Up @@ -334,12 +377,13 @@ void FutureImpl::MarkFinished() { GetConcreteFuture(this)->DoMarkFinished(); }

void FutureImpl::MarkFailed() { GetConcreteFuture(this)->DoMarkFailed(); }

void FutureImpl::AddCallback(Callback callback) {
GetConcreteFuture(this)->AddCallback(std::move(callback));
// Forward to the concrete implementation, passing the callback options through.
void FutureImpl::AddCallback(Callback callback, CallbackOptions opts) {
  auto&& concrete = GetConcreteFuture(this);
  concrete->AddCallback(std::move(callback), opts);
}

bool FutureImpl::TryAddCallback(const std::function<Callback()>& callback_factory) {
return GetConcreteFuture(this)->TryAddCallback(callback_factory);
// Forward to the concrete implementation, passing the callback options through, and
// report back whatever it returns.
bool FutureImpl::TryAddCallback(const std::function<Callback()>& callback_factory,
                                CallbackOptions opts) {
  auto&& concrete = GetConcreteFuture(this);
  return concrete->TryAddCallback(callback_factory, opts);
}

Future<> AllComplete(const std::vector<Future<>>& futures) {
Expand Down
92 changes: 64 additions & 28 deletions cpp/src/arrow/util/future.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,30 @@ enum class FutureState : int8_t { PENDING, SUCCESS, FAILURE };

/// \brief Whether `state` denotes a finished future, i.e. any state other than PENDING
inline bool IsFutureFinished(FutureState state) {
  const bool still_pending = (state == FutureState::PENDING);
  return !still_pending;
}

/// \brief Describe whether the callback should be scheduled or run synchronously
enum class ShouldSchedule {
  /// Run the callback synchronously, never as a separate task (the default)
  Never = 0,
  /// Schedule the callback as a new task only when the future has not yet
  /// finished at the time the callback is added
  IfUnfinished,
  /// Schedule the callback as a new task in every case
  Always
};

/// \brief Options that control how a continuation is run
struct CallbackOptions {
  /// Describe whether the callback should be run synchronously or scheduled
  ShouldSchedule should_schedule = ShouldSchedule::Never;
  /// If the callback is scheduled then this is the executor it should be scheduled
  /// on.  If this is null then should_schedule must be Never
  internal::Executor* executor = nullptr;

  /// \return options with every field at its default value: run the callback
  /// synchronously, with no executor
  static CallbackOptions Defaults() { return CallbackOptions(); }
};

// Untyped private implementation
class ARROW_EXPORT FutureImpl {
class ARROW_EXPORT FutureImpl : public std::enable_shared_from_this<FutureImpl> {
public:
FutureImpl();
virtual ~FutureImpl() = default;
Expand All @@ -218,10 +240,15 @@ class ARROW_EXPORT FutureImpl {
void MarkFailed();
void Wait();
bool Wait(double seconds);
/// \brief View the type-erased result storage as a `Result<ValueType>`
///
/// The conversion is an unchecked static_cast from `void*`: the caller is
/// responsible for naming the ValueType the result was stored with.
template <typename ValueType>
Result<ValueType>* CastResult() const {
  void* erased = result_.get();
  return static_cast<Result<ValueType>*>(erased);
}

using Callback = internal::FnOnce<void()>;
void AddCallback(Callback callback);
bool TryAddCallback(const std::function<Callback()>& callback_factory);
using Callback = internal::FnOnce<void(const FutureImpl& impl)>;
void AddCallback(Callback callback, CallbackOptions opts);
bool TryAddCallback(const std::function<Callback()>& callback_factory,
CallbackOptions opts);

// Waiter API
inline FutureState SetWaiter(FutureWaiter* w, int future_num);
Expand All @@ -234,7 +261,11 @@ class ARROW_EXPORT FutureImpl {
using Storage = std::unique_ptr<void, void (*)(void*)>;
Storage result_{NULLPTR, NULLPTR};

std::vector<Callback> callbacks_;
// A registered callback bundled with the options it was added with.
struct CallbackRecord {
// The continuation to invoke
Callback callback;
// How the continuation should be run (synchronously or on an executor)
CallbackOptions options;
};
std::vector<CallbackRecord> callbacks_;
};

// An object that waits on multiple futures at once. Only one waiter
Expand Down Expand Up @@ -453,30 +484,34 @@ class Future {
/// cyclic reference to itself through the callback.
template <typename OnComplete>
typename std::enable_if<!detail::first_arg_is_status<OnComplete>::value>::type
AddCallback(OnComplete on_complete) const {
AddCallback(OnComplete on_complete,
CallbackOptions opts = CallbackOptions::Defaults()) const {
// We know impl_ will not be dangling when invoking callbacks because at least one
// thread will be waiting for MarkFinished to return. Thus it's safe to keep a
// weak reference to impl_ here
struct Callback {
void operator()() && { std::move(on_complete)(weak_self.get().result()); }
WeakFuture<T> weak_self;
// Hand the Result stored in `impl` to the wrapped on_complete callback, which is
// invoked as an rvalue.
void operator()(const FutureImpl& impl) && {
  auto& finished_result = *impl.CastResult<ValueType>();
  std::move(on_complete)(finished_result);
}
OnComplete on_complete;
};
impl_->AddCallback(Callback{WeakFuture<T>(*this), std::move(on_complete)});
impl_->AddCallback(Callback{std::move(on_complete)}, opts);
}

/// Overload for callbacks accepting a Status
template <typename OnComplete>
typename std::enable_if<detail::first_arg_is_status<OnComplete>::value>::type
AddCallback(OnComplete on_complete) const {
AddCallback(OnComplete on_complete,
CallbackOptions opts = CallbackOptions::Defaults()) const {
static_assert(std::is_same<internal::Empty, ValueType>::value,
"Callbacks for Future<> should accept Status and not Result");
struct Callback {
void operator()() && { std::move(on_complete)(weak_self.get().status()); }
WeakFuture<T> weak_self;
// Pass the status of the Result stored in `impl` to the wrapped on_complete
// callback, which is invoked as an rvalue.
void operator()(const FutureImpl& impl) && {
std::move(on_complete)(impl.CastResult<ValueType>()->status());
}
OnComplete on_complete;
};
impl_->AddCallback(Callback{WeakFuture<T>(*this), std::move(on_complete)});
impl_->AddCallback(Callback{std::move(on_complete)}, opts);
}

/// \brief Overload of AddCallback that will return false instead of running
Expand All @@ -495,30 +530,33 @@ class Future {
template <typename CallbackFactory,
typename OnComplete = detail::result_of_t<CallbackFactory()>>
typename std::enable_if<!detail::first_arg_is_status<OnComplete>::value, bool>::type
TryAddCallback(const CallbackFactory& callback_factory) const {
TryAddCallback(const CallbackFactory& callback_factory,
CallbackOptions opts = CallbackOptions::Defaults()) const {
struct Callback {
void operator()() && { std::move(on_complete)(weak_self.get().result()); }
WeakFuture<T> weak_self;
// Hand the Result stored in `impl` to the wrapped on_complete callback, which is
// invoked as an rvalue.  The typed result is obtained through
// FutureImpl::CastResult rather than by casting the result_ storage by hand.
void operator()(const FutureImpl& impl) && {
  std::move(on_complete)(*impl.CastResult<ValueType>());
}
OnComplete on_complete;
};
return impl_->TryAddCallback([this, &callback_factory]() {
return Callback{WeakFuture<T>(*this), callback_factory()};
});
return impl_->TryAddCallback(
[&callback_factory]() { return Callback{callback_factory()}; }, opts);
}

template <typename CallbackFactory,
typename OnComplete = detail::result_of_t<CallbackFactory()>>
typename std::enable_if<detail::first_arg_is_status<OnComplete>::value, bool>::type
TryAddCallback(const CallbackFactory& callback_factory) const {
TryAddCallback(const CallbackFactory& callback_factory,
CallbackOptions opts = CallbackOptions::Defaults()) const {
struct Callback {
void operator()() && { std::move(on_complete)(weak_self.get().status()); }
WeakFuture<T> weak_self;
// Pass the status of the Result stored in `impl` to the wrapped on_complete
// callback, which is invoked as an rvalue.  The typed result is obtained through
// FutureImpl::CastResult rather than by casting the result_ storage by hand.
void operator()(const FutureImpl& impl) && {
  std::move(on_complete)(impl.CastResult<ValueType>()->status());
}
OnComplete on_complete;
};

return impl_->TryAddCallback([this, &callback_factory]() {
return Callback{WeakFuture<T>(*this), callback_factory()};
});
return impl_->TryAddCallback(
[&callback_factory]() { return Callback{callback_factory()}; }, opts);
}

/// \brief Consumer API: Register a continuation to run when this future completes
Expand Down Expand Up @@ -696,9 +734,7 @@ class Future {

// Point impl_ at a new implementation object obtained from FutureImpl::Make().
void Initialize() {
  impl_ = FutureImpl::Make();
}

Result<ValueType>* GetResult() const {
return static_cast<Result<ValueType>*>(impl_->result_.get());
}
// Typed pointer to the result storage held by impl_, obtained through
// FutureImpl::CastResult.
Result<ValueType>* GetResult() const {
  return impl_->CastResult<ValueType>();
}

void SetResult(Result<ValueType> res) {
impl_->result_ = {new Result<ValueType>(std::move(res)),
Expand Down
Loading