Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions strings/base_coroutine_foundation.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,21 @@ namespace winrt::impl
};

template <typename Async>
struct await_adapter
struct await_adapter : enable_await_cancellation
{
await_adapter(Async const& async) : async(async) { }

Async const& async;
Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started;

void enable_cancellation(cancellable_promise* promise)
{
promise->set_canceller([](void* parameter)
{
cancel_asynchronously(reinterpret_cast<await_adapter*>(parameter)->async);
}, this);
}

bool await_ready() const noexcept
{
return false;
Expand All @@ -153,6 +163,19 @@ namespace winrt::impl
check_status_canceled(status);
return async.GetResults();
}

private:
static fire_and_forget cancel_asynchronously(Async async)
{
co_await winrt::resume_background();
try
{
async.Cancel();
}
catch (hresult_error const&)
{
}
}
};

template <typename D>
Expand Down Expand Up @@ -278,6 +301,11 @@ namespace winrt::impl
m_promise->cancellation_callback(std::move(cancel));
}

bool enable_propagation(bool value = true) const noexcept
{
return m_promise->enable_cancellation_propagation(value);
}

private:

Promise* m_promise;
Expand Down Expand Up @@ -414,6 +442,8 @@ namespace winrt::impl
{
cancel();
}

m_cancellable.cancel();
}

void Close() const noexcept
Expand Down Expand Up @@ -536,7 +566,7 @@ namespace winrt::impl
throw winrt::hresult_canceled();
}

return notify_awaiter<Expression>{ static_cast<Expression&&>(expression) };
return notify_awaiter<Expression>{ static_cast<Expression&&>(expression), m_propagate_cancellation ? &m_cancellable : nullptr };
}

cancellation_token<Derived> await_transform(get_cancellation_token_t) noexcept
Expand Down Expand Up @@ -567,6 +597,11 @@ namespace winrt::impl
}
}

bool enable_cancellation_propagation(bool value) noexcept
{
return std::exchange(m_propagate_cancellation, value);
}

#if defined(_DEBUG) && !defined(WINRT_NO_MAKE_DETECTION)
void use_make_function_to_create_this_object() final
{
Expand All @@ -587,8 +622,10 @@ namespace winrt::impl
slim_mutex m_lock;
async_completed_handler_t<AsyncInterface> m_completed;
winrt::delegate<> m_cancel;
cancellable_promise m_cancellable;
std::atomic<AsyncStatus> m_status;
bool m_completed_assigned{ false };
bool m_propagate_cancellation{ false };
};
}

Expand Down
171 changes: 162 additions & 9 deletions strings/base_coroutine_threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,81 @@ namespace winrt::impl
static constexpr bool has_co_await_member = find_co_await_member<T&&>(0);
static constexpr bool has_co_await_free = find_co_await_free<T&&>(0);
};
}

WINRT_EXPORT namespace winrt
{
struct cancellable_promise
{
using canceller_t = void(*)(void*);

void set_canceller(canceller_t canceller, void* context)
{
m_context = context;
m_canceller.store(canceller, std::memory_order_release);
}

void revoke_canceller()
{
while (m_canceller.exchange(nullptr, std::memory_order_acquire) == cancelling_ptr)
{
std::this_thread::yield();
}
}

void cancel()
{
auto canceller = m_canceller.exchange(cancelling_ptr, std::memory_order_acquire);
struct unique_cancellation_lock
{
cancellable_promise* promise;
~unique_cancellation_lock()
{
promise->m_canceller.store(nullptr, std::memory_order_release);
}
} lock{ this };

if ((canceller != nullptr) && (canceller != cancelling_ptr))
{
canceller(m_context);
}
}

private:
static inline auto const cancelling_ptr = reinterpret_cast<canceller_t>(1);

std::atomic<canceller_t> m_canceller{ nullptr };
void* m_context{ nullptr };
};

struct enable_await_cancellation
{
enable_await_cancellation() noexcept = default;
enable_await_cancellation(enable_await_cancellation const&) = delete;

~enable_await_cancellation()
{
if (m_promise)
{
m_promise->revoke_canceller();
}
}

void operator=(enable_await_cancellation const&) = delete;

void set_cancellable_promise(cancellable_promise* promise) noexcept
{
m_promise = promise;
}

private:

cancellable_promise* m_promise = nullptr;
};
}

namespace winrt::impl
{
template <typename T>
decltype(auto) get_awaiter(T&& value) noexcept
{
Expand All @@ -149,8 +223,16 @@ namespace winrt::impl
{
decltype(get_awaiter(std::declval<T&&>())) awaitable;

notify_awaiter(T&& awaitable) : awaitable(get_awaiter(static_cast<T&&>(awaitable)))
notify_awaiter(T&& awaitable_arg, cancellable_promise* promise = nullptr) : awaitable(get_awaiter(static_cast<T&&>(awaitable_arg)))
{
if constexpr (std::is_convertible_v<std::remove_reference_t<decltype(awaitable)>&, enable_await_cancellation&>)
{
if (promise)
{
static_cast<enable_await_cancellation&>(awaitable).set_cancellable_promise(promise);
awaitable.enable_cancellation(promise);
}
}
}

bool await_ready()
Expand Down Expand Up @@ -271,34 +353,67 @@ WINRT_EXPORT namespace winrt

[[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept
{
struct awaitable
struct awaitable : enable_await_cancellation
{
explicit awaitable(Windows::Foundation::TimeSpan duration) noexcept :
m_duration(duration)
{
}

void enable_cancellation(cancellable_promise* promise)
{
promise->set_canceller([](void* context)
{
auto that = static_cast<awaitable*>(context);
if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending)
{
that->fire_immediately();
}
}, this);
}

bool await_ready() const noexcept
{
return m_duration.count() <= 0;
}

void await_suspend(std::experimental::coroutine_handle<> handle)
{
m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, handle.address(), nullptr)));
m_handle = handle;
m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, this, nullptr)));
int64_t relative_count = -m_duration.count();
WINRT_IMPL_SetThreadpoolTimer(m_timer.get(), &relative_count, 0, 0);
WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &relative_count, 0, 0);

state expected = state::idle;
if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
{
fire_immediately();
}
}

void await_resume() const noexcept
void await_resume()
{
if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled)
{
throw hresult_canceled();
}
}

private:

void fire_immediately() noexcept
{
if (WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), nullptr, 0, 0))
{
int64_t now = 0;
WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &now, 0, 0);
}
}

static void __stdcall callback(void*, void* context, void*) noexcept
{
std::experimental::coroutine_handle<>::from_address(context)();
auto that = reinterpret_cast<awaitable*>(context);
that->m_handle();
}

struct timer_traits
Expand All @@ -316,8 +431,12 @@ WINRT_EXPORT namespace winrt
}
};

enum class state { idle, pending, canceled };

handle_type<timer_traits> m_timer;
Windows::Foundation::TimeSpan m_duration;
std::experimental::coroutine_handle<> m_handle;
std::atomic<state> m_state{ state::idle };
};

return awaitable{ duration };
Expand All @@ -332,13 +451,25 @@ WINRT_EXPORT namespace winrt

[[nodiscard]] inline auto resume_on_signal(void* handle, Windows::Foundation::TimeSpan timeout = {}) noexcept
{
struct awaitable
struct awaitable : enable_await_cancellation
{
awaitable(void* handle, Windows::Foundation::TimeSpan timeout) noexcept :
m_timeout(timeout),
m_handle(handle)
{}

void enable_cancellation(cancellable_promise* promise)
{
promise->set_canceller([](void* context)
{
auto that = static_cast<awaitable*>(context);
if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending)
{
that->fire_immediately();
}
}, this);
}

bool await_ready() const noexcept
{
return WINRT_IMPL_WaitForSingleObject(m_handle, 0) == 0;
Expand All @@ -350,16 +481,35 @@ WINRT_EXPORT namespace winrt
m_wait.attach(check_pointer(WINRT_IMPL_CreateThreadpoolWait(callback, this, nullptr)));
int64_t relative_count = -m_timeout.count();
int64_t* file_time = relative_count != 0 ? &relative_count : nullptr;
WINRT_IMPL_SetThreadpoolWait(m_wait.get(), m_handle, file_time);
WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), m_handle, file_time, nullptr);

state expected = state::idle;
if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
{
fire_immediately();
}
}

bool await_resume() const noexcept
bool await_resume()
{
if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled)
{
throw hresult_canceled();
}
return m_result == 0;
}

private:

void fire_immediately() noexcept
{
if (WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), nullptr, nullptr, nullptr))
{
int64_t now = 0;
WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), WINRT_IMPL_GetCurrentProcess(), &now, nullptr);
}
}

static void __stdcall callback(void*, void* context, void*, uint32_t result) noexcept
{
auto that = static_cast<awaitable*>(context);
Expand All @@ -382,11 +532,14 @@ WINRT_EXPORT namespace winrt
}
};

enum class state { idle, pending, canceled };

handle_type<wait_traits> m_wait;
Windows::Foundation::TimeSpan m_timeout;
void* m_handle;
uint32_t m_result{};
std::experimental::coroutine_handle<> m_resume{ nullptr };
std::atomic<state> m_state{ state::idle };
};

return awaitable{ handle, timeout };
Expand Down
8 changes: 4 additions & 4 deletions strings/base_extern.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ extern "C"

int32_t __stdcall WINRT_IMPL_TrySubmitThreadpoolCallback(void(__stdcall *callback)(void*, void* context), void* context, void*) noexcept;
winrt::impl::ptp_timer __stdcall WINRT_IMPL_CreateThreadpoolTimer(void(__stdcall *callback)(void*, void* context, void*), void* context, void*) noexcept;
void __stdcall WINRT_IMPL_SetThreadpoolTimer(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window) noexcept;
int32_t __stdcall WINRT_IMPL_SetThreadpoolTimerEx(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window) noexcept;
void __stdcall WINRT_IMPL_CloseThreadpoolTimer(winrt::impl::ptp_timer timer) noexcept;
winrt::impl::ptp_wait __stdcall WINRT_IMPL_CreateThreadpoolWait(void(__stdcall *callback)(void*, void* context, void*, uint32_t result), void* context, void*) noexcept;
void __stdcall WINRT_IMPL_SetThreadpoolWait(winrt::impl::ptp_wait wait, void* handle, void* timeout) noexcept;
int32_t __stdcall WINRT_IMPL_SetThreadpoolWaitEx(winrt::impl::ptp_wait wait, void* handle, void* timeout, void* reserved) noexcept;
void __stdcall WINRT_IMPL_CloseThreadpoolWait(winrt::impl::ptp_wait wait) noexcept;
winrt::impl::ptp_io __stdcall WINRT_IMPL_CreateThreadpoolIo(void* object, void(__stdcall *callback)(void*, void* context, void* overlapped, uint32_t result, std::size_t bytes, void*) noexcept, void* context, void*) noexcept;
void __stdcall WINRT_IMPL_StartThreadpoolIo(winrt::impl::ptp_io io) noexcept;
Expand Down Expand Up @@ -147,10 +147,10 @@ WINRT_IMPL_LINK(WaitForSingleObject, 8)

WINRT_IMPL_LINK(TrySubmitThreadpoolCallback, 12)
WINRT_IMPL_LINK(CreateThreadpoolTimer, 12)
WINRT_IMPL_LINK(SetThreadpoolTimer, 16)
WINRT_IMPL_LINK(SetThreadpoolTimerEx, 16)
WINRT_IMPL_LINK(CloseThreadpoolTimer, 4)
WINRT_IMPL_LINK(CreateThreadpoolWait, 12)
WINRT_IMPL_LINK(SetThreadpoolWait, 12)
WINRT_IMPL_LINK(SetThreadpoolWaitEx, 16)
WINRT_IMPL_LINK(CloseThreadpoolWait, 4)
WINRT_IMPL_LINK(CreateThreadpoolIo, 16)
WINRT_IMPL_LINK(StartThreadpoolIo, 4)
Expand Down
1 change: 1 addition & 0 deletions strings/base_includes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <stdexcept>
#include <string_view>
#include <string>
#include <thread>
#include <tuple>
#include <type_traits>
#include <unordered_map>
Expand Down
Loading