From 6a3f7727f8dd68be04c4b0392c54e6f283d17c3a Mon Sep 17 00:00:00 2001 From: Raymond Chen Date: Mon, 3 Aug 2020 22:14:55 -0700 Subject: [PATCH 1/2] Allow cancellation to be propagated to child coroutines The cancellation token learned a new method `enable_propagation()`. It takes an optional boolean parameter (default true) which specifies whether propagation of cancellation into child coroutines is enabled. It returns the previous setting, in case you want to restore it. If cancellation propagation is enabled in a coroutine, then calling `IAsyncXxx.Cancel()` on the coroutine's asynchronous activity will try to propagate the cancellation into whatever the coroutine is `co_await`ing for. Currently, the following awaitables are supported: * `IAsyncXxx` * `resume_after()` * `resume_on_signal()` Example: ```cpp IAsyncAction CheckAsync() { // Enable cancellation propagation. auto cancel = co_await get_cancellation_token(); cancel.enable_propagation(); HttpClient client; auto result = co_await client.TryGetAsync(Uri(L"https://www.microsoft.com")); co_return result.Succeeded() && result.ResponseMessage().IsSuccessStatusCode(); } IAsyncAction checkOperation; fire_and_forget StartButton_Click() { checkOperation = CheckAsync(); try { if (co_await checkOperation) { // Everything is just fine. } else { // Couldn't connect to server. } } catch (hresult_canceled const&) { // Operation was canceled. } } void CancelButton_Click() { checkOperation.Cancel(); } ``` When the user clicks the Start button, we begin the `CheckAsync()` operation and save it. If the operation is taking too long and the user clicks the Cancel button, we Cancel the operation. Normally, this would cancel the `CheckAsync()` operation, but that's all. The `TryGetAsync()` call inside `CheckAsync()` continues to run to its normal completion. The `CheckAsync()` operation has been cancelled, and the `catch` clause runs, but the `TryGetAsync()` call is still running. Enabling cancellation propagation means that when the `CheckAsync()` call is canceled, the `TryGetAsync()` call is also canceled. Cancellation propagation allows a coroutine to respond more quickly to cancellation. Without cancellation propagation, the coroutine would have to wait for the `TryGetAsync()` to complete naturally (probably due to network timeout) before it could clean up. In some cases, the natural completion of the awaitable could be very long (`resume_after(24h)`) or may never happen at all (`resume_on_signal(never_signaled)`), causing the coroutine to become effectively leaked, since it is waiting for something that will never happen. > **Note**: All names are provisional. Suggestions for better > names are gratefully welcomed. An awaiter can mark itself as supporting cancellation propagation by inheriting publically from `enable_await_cancellation`. (This is analogous to `enable_shared_from_this`.) If the awaiter is created by a coroutine in which cancellation propagation has been enabled, its `enable_cancellation()` method is called with a `cancellable_promise` pointer. The `cancellable_promise` represents a hook into the cancellation of a coroutine, so that an awaiter can cancel the awaited-for object when the calling coroutine is cancelled. The awaiter's `enable_cancellation()` calls the cancellable promise's `set_cancelled()` method, passing a callback function pointer and a context pointer (typically the awaiter's `this` pointer). If the outer coroutine is cancelled, then the callback function is called with the context pointer as a parameter. The callback function should arrange for the cancellation of the awaited-on object so that the `co_await` can complete. It should also arrange that the `await_resume()` method throws the `hresult_canceled()` exception. The cancellation callback is called under a spinlock, so it should work quickly. If it needs to do expensive work, it should finish the work asynchronously. Note that once the cancellation callback returns, the awaiter is eligible for destruction, so make sure to extend the lifetimes of any objects you need if they would normally be destructed by the awaiter. An example of this pattern can be found in the standard awaiter for `IAsyncAction` (known internally as `await_adapter`). Here is a simplified version: ```cpp template struct await_adapter : enable_await_cancellation { await_adapter(Async const& async) : async(async) { } Async const& async; void enable_cancellation(cancellable_promise* promise) { promise->set_canceller([](void* context) { cancel_asynchronously(reinterpret_cast(context)->async); }, this); fire_and_forget cancel_asynchronously(Async async) { co_await resume_background(); async.Cancel(); } ... other awaiter stuff unchanged ... }; ``` When cancellation is enabled, we set a lambda as our canceller, with the `await_adapter` itself as the context parameter. If cancellation is indeed requested, we make a copy of the `async` (which is a single call to `AddRef`, which should be fast), and pass it to `cancel_asynchronously()` to continue the work on a background thread. (We cleverly use coroutines to help implement coroutines.) Cancelling on a background thread solves two problems. The first is that it allows execution on the main thread to continue, releasing the spinlock. The second is that it avoids deep recursion, because cancelling the awaited-for asynchronous activity could very well trigger another cancellation propagation if that asynchronous activity was itself waiting for something else. The cancellation propagates all the way down the call chain until it reaches a primitive like `resume_on_signal()`, or reaches something that does not support cancellation propagation. We do not need to make any changes to `await_resume()` because the asynchronous operation will complete with the `Canceled` status, and the existing code already converts that to an `hresult_canceled` exception. The actual implementation is slightly more complicated in that the call to `async.Cancel()` is itself done under a try/catch to deal with the possibility that cancellation fails, for example, because the RPC server has died. The cancellation is best effort, and in the case of RPC server death, the loss of the operation will be detected by the `disconnect_aware_handler` and converted into an RPC exception. Awaiters for custom objects may also derive from `enable_await_cancellation` in order to participate in cancellation propagation. For example: ```cpp struct io_operation { struct awaitable : enable_await_cancellation { io_operation& m_op; awaitable(io_operation& op) : m_op(op) { } void enable_cancellation(cancellable_promise* promise) { promise->set_canceller([](void* context) { auto op = static_cast(context); CancelIoEx(op->handle, &op->overlapped); }, std::addressof(m_op)); } ... other awaitable members ... }; auto operator co_await() { return awaitable{ *this }; } ... other io_operation members ... }; ``` The awaitable needs to be ready to cancel the operation even before `await_ready()` is called, or after `await_resume()` has been called. (In the latter case, of course, the operation has already completed, so any attempt to cancel it is moot. Nevertheless, the code needs to handle the case of a too-late cancellations.) ## Implementation details One of the main motivations behind the design is to keep the "no cancellation" code path fast. Cancellation is extremely rare, so it is better to make the "no cancellation" code path fast, even if it comes at the expense of making the cancellation code path slower or more complicated. We use a callback function and pointer because they are lighter weight than an object with a virtual method. The portion of the coroutine promise that is exposed to clients is the `cancellable_promise`. The methods are as follows: * `set_canceller`: Called by the implementation of the awaiter. * `revoke_canceller`: Call by the destructor of `enable_cancellation` to clean up when the awaiter destructs. * `cancel`: Called by the promise when the enclosing coroutine is cancelled. The canceller function pointer can hold one of three values: * `nullptr`: No canceller has been set, or cancellation has already been attempted. * canceller: The canceller to call. * `cancelling_ptr`: Sentinel value that indicates that cancellation is in progress. The `set_canceller` method sets the canceller pointer with release semantics so that the state of the awaiter is published to the cancelling thread. The `revoke_canceller` method sets the canceller to `nullptr`, but spins if the previous value was `cancelling_ptr`, indicating that we should not destruct the awaiter because cancellation is in progress. It uses acquire memory order so that the state of the awaiter is properly received from the cancelling thread (if any) before we destruct it. The `cancel` method atomically obtains the current canceller and sets the canceller to the `cancelling_ptr` sentinel value to indicate that cancellation is in progress. This exchange is performed with acquire semantics to ensure that the state of the awaiter is properly received from the publishing thread before we call the callback. The `cancel` method then uses an RAII type to ensure that the canceller returns to `nullptr` when the cancel method is done. An RAII type is important here, so that we properly clean up even if the canceller callback throws an exception. If there is a canceller callback, we call it. (If the function pointer is `cancelling_ptr`, then it means that cancellation is already in progress, so we shouldn't try to cancel again.) The RAII type returns the canceller to `nullptr` with release semantics so that the state of the awaiter is published to awaiting thread. The `enable_await_cancellation` class serves both as a marker class (so that we know that the awaiter supports cancellation propagation) as well as handling the bookkeeping of revoking the canceller when the awaiter destructs. Detecting the `enable_await_cancellation` is done by `is_convertible_v` to avoid edge cases like `operator co_await` returning a const object as an awaiter. (It also means that an awaiter can delegate the cancellation to a sub-object if it wants to take finer control over object layout.) The `notify_awaiter` is augmented to receive a `cancellable_promise` pointer as part of its construction, which is propagated into the wrapped awaiter's `enable_await_cancellation` if present. This also triggers the call to `enable_cancellation`. The `promise_base` contains a `cancellable_promise`, which it passes to the `notify_awaiter` constructor if cancellation propagation is enabled. It calls the `cancellable_promise::cancel()` method when the coroutine is cancelled in order to propagate the cancellation. The `cancellation_token` has a new `enable_propagation` method which calls back into the promise to specify whether cancellation propagation is now enabled or not, and it returns the previous state, in case you want to restore it. Cancellation propagation support for `resume_after` and `resume_on_signal` are very similar. To avoid race conditions if the coroutine is cancelled while `await_suspend` is still setting up the thread pool objects, we track the state of the awaiter in an atomic variable whose values are either `idle` (no thread pool object is active), `pending` (waiting for thread pool object to queue a callback), and `canceled` (the operation has been cancelled). After creating the thread pool objects, we transition from `idle` to `pending` and accelerate to completion if the coroutine has already been cancelled. This transition uses release semantics so that the state of the awaiter is published to any potential cancelling thread. When resuming, we transition back to `idle` and throw the `hresult_canceled()` exception if we had been cancelled. This transition can be done with relaxed semantics because the cancellation callback does not mutate the awaiter. If cancellation is requested, we transition unconditionally to the cancelled state, and if the previous state was pending, we also have to accelerate to completion so that we can get out of the pending state. This transition requires acquire semantics because we need to read the awaiter state that was published by `await_suspend`. To accelerate to completion, we first call the corresponding `Set...Ex` function with the magic parameters that cancel the thread pool callback. The possible return values are * If the thread pool object hasn't been set yet, then it returns `FALSE`. This case cannot happen because we don't transition to `pending` state until after we set the thread pool callback. * If the thread pool object was set and hasn't been signaled, then it returns `TRUE`. In this case, we need to force a new callback. * If the thread pool object was set and already scheduled a thread pool callback (which may or not have been called yet), then it returns `FALSE`. In this case, we don't need to force a callback, because one is on its way or has already occurred. Therefore, we need to force a new callback if the `Set...Ex` function returns `TRUE`. For timers, we do this by resetting the timer with a due time of zero, meaning "now". For waits, we do this by resetting the wait with the current process handle (which will never be signaled) and a timeout of zero, meaning "now". The subtlety in the case of a wait is that we use the current process handle rather than the original handle, because the original handle may have gone invalid in the meantime, and we don't want the `SetThreadpoolWaitEx` to fail with `ERROR_INVALID_HANDLE`. Additional cost for `co_await` is a boolean test (to see whether cancellation propagation is enabled) that is used to decide whether to pass the address of a member variable or a null pointer. That pointer is then tested for non-null, and if non-null, the `set_canceller` method is called to register the cancellation callbacks. This pointer is also checked when the awaiter is destructed, and if non-null, then cleanup is performed. Cost summary: A bool (slips into padding) and two pointers in the promise. If cancellation propagation is not enabled, a boolean test, a memory write of a null pointer, and two null pointer tests. Constant propagation can remove the first null pointer test. The memory write is to the awaiter, which is being constructed, so it will share the cache line with the other members of the awaiter. If cancellation propagation is enabled, but no cancellation occurs, then a boolean test, a null pointer test, writing two pointers to memory, and a release fence; when the await completes, a null pointer test and an acquire fence. If cancellation propagation is enabled, and cancellation occurs, then the cost goes up significantly, but that is expected because that's the point of the feature. Code generation for MSVC-x64 at await: xor rbx, rbx lea rcx, m_canceller cmp byte ptr m_propagate_cancellation, 0 cmove rcx, rbx ; rcx = nullptr or &m_canceller mov awaiter.m_promise, rbx (initialize other members of the awaiter) test rcx, rcx ; Q: is cancellation propagation enabled? jz skip ; N: Don't register for cancellation lea rax, awaiter ; m_context for canceller mov awaiter.m_promise, rcx ; awaiter.m_promise = &m_canceller mov awaiter.m_context, rax ; cancellation context lea rax, [callback] ; callback function mov awaiter.m_canceller, rax skip: Profile-guided optimization may help the compiler decide which branch to optimize as the fallthrough case. Code generation at resume (awaiter destruction): mov rbx, m_promise test rbx, rbx ; Q: need to clear the callback? je done ; N: Nothing to do xor eax, eax xchg rax, m_promise->m_canceller ; try to clear the callback cmp rax, 1 ; Q: racing against canceller? jne done ; N: nothing to do spin: call _Thrd_yield ; Let the canceller finish xor eax, eax xchg rax, m_promise->m_canceller ; try to clear the callback cmp rax, 1 ; Q: racing against canceller? je spin done: Profile-guided optimization may help the compiler realize that spinning almost never happens. Fixes #690 --- strings/base_coroutine_foundation.h | 41 ++++++- strings/base_coroutine_threadpool.h | 171 +++++++++++++++++++++++++-- strings/base_extern.h | 8 +- strings/base_includes.h | 1 + test/test/async_propagate_cancel.cpp | 139 ++++++++++++++++++++++ test/test/test.vcxproj | 1 + 6 files changed, 346 insertions(+), 15 deletions(-) create mode 100644 test/test/async_propagate_cancel.cpp diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index a3862ad6d..666268015 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -128,11 +128,21 @@ namespace winrt::impl }; template - struct await_adapter + struct await_adapter : enable_await_cancellation { + await_adapter(Async const& async) : async(async) { } + Async const& async; Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started; + void enable_cancellation(cancellable_promise* promise) + { + promise->set_canceller([](void* parameter) + { + cancel_asynchronously(reinterpret_cast(parameter)->async); + }, this); + } + bool await_ready() const noexcept { return false; @@ -153,6 +163,19 @@ namespace winrt::impl check_status_canceled(status); return async.GetResults(); } + + private: + static fire_and_forget cancel_asynchronously(Async async) + { + co_await winrt::resume_background(); + try + { + async.Cancel(); + } + catch (hresult_error const&) + { + } + } }; template @@ -278,6 +301,11 @@ namespace winrt::impl m_promise->cancellation_callback(std::move(cancel)); } + bool enable_propagation(bool value = true) const noexcept + { + return m_promise->enable_cancellation_propagation(value); + } + private: Promise* m_promise; @@ -414,6 +442,8 @@ namespace winrt::impl { cancel(); } + + m_cancellable.cancel(); } void Close() const noexcept @@ -536,7 +566,7 @@ namespace winrt::impl throw winrt::hresult_canceled(); } - return notify_awaiter{ static_cast(expression) }; + return notify_awaiter{ static_cast(expression), m_propagate_cancellation ? &m_cancellable : nullptr }; } cancellation_token await_transform(get_cancellation_token_t) noexcept @@ -567,6 +597,11 @@ namespace winrt::impl } } + bool enable_cancellation_propagation(bool value) noexcept + { + return std::exchange(m_propagate_cancellation, value); + } + #if defined(_DEBUG) && !defined(WINRT_NO_MAKE_DETECTION) void use_make_function_to_create_this_object() final { @@ -587,8 +622,10 @@ namespace winrt::impl slim_mutex m_lock; async_completed_handler_t m_completed; winrt::delegate<> m_cancel; + cancellable_promise m_cancellable; std::atomic m_status; bool m_completed_assigned{ false }; + bool m_propagate_cancellation{ false }; }; } diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index 3b1e64506..ac309b377 100644 --- a/strings/base_coroutine_threadpool.h +++ b/strings/base_coroutine_threadpool.h @@ -124,7 +124,81 @@ namespace winrt::impl static constexpr bool has_co_await_member = find_co_await_member(0); static constexpr bool has_co_await_free = find_co_await_free(0); }; +} + +WINRT_EXPORT namespace winrt +{ + struct cancellable_promise + { + using canceller_t = void(*)(void*); + + void set_canceller(canceller_t canceller, void* context) + { + m_context = context; + m_canceller.store(canceller, std::memory_order_release); + } + + void revoke_canceller() + { + while (m_canceller.exchange(nullptr, std::memory_order_acquire) == cancelling_ptr) + { + std::this_thread::yield(); + } + } + + void cancel() + { + auto canceller = m_canceller.exchange(cancelling_ptr, std::memory_order_acquire); + struct unique_cancellation_lock + { + cancellable_promise* promise; + ~unique_cancellation_lock() + { + promise->m_canceller.store(nullptr, std::memory_order_release); + } + } lock{ this }; + + if ((canceller != nullptr) && (canceller != cancelling_ptr)) + { + canceller(m_context); + } + } + + private: + static inline auto const cancelling_ptr = reinterpret_cast(1); + + std::atomic m_canceller{ nullptr }; + void* m_context{ nullptr }; + }; + + struct enable_await_cancellation + { + enable_await_cancellation() noexcept = default; + enable_await_cancellation(enable_await_cancellation const&) = delete; + + ~enable_await_cancellation() + { + if (m_promise) + { + m_promise->revoke_canceller(); + } + } + + void operator=(enable_await_cancellation const&) = delete; + + void set_cancellable_promise(cancellable_promise* promise) noexcept + { + m_promise = promise; + } + + private: + + cancellable_promise* m_promise = nullptr; + }; +} +namespace winrt::impl +{ template decltype(auto) get_awaiter(T&& value) noexcept { @@ -149,8 +223,16 @@ namespace winrt::impl { decltype(get_awaiter(std::declval())) awaitable; - notify_awaiter(T&& awaitable) : awaitable(get_awaiter(static_cast(awaitable))) + notify_awaiter(T&& awaitable_arg, cancellable_promise* promise = nullptr) : awaitable(get_awaiter(static_cast(awaitable_arg))) { + if constexpr (std::is_convertible_v&, enable_await_cancellation&>) + { + if (promise) + { + static_cast(awaitable).set_cancellable_promise(promise); + awaitable.enable_cancellation(promise); + } + } } bool await_ready() @@ -271,13 +353,25 @@ WINRT_EXPORT namespace winrt [[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept { - struct awaitable + struct awaitable : enable_await_cancellation { explicit awaitable(Windows::Foundation::TimeSpan duration) noexcept : m_duration(duration) { } + void enable_cancellation(cancellable_promise* promise) + { + promise->set_canceller([](void* context) + { + auto that = static_cast(context); + if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending) + { + that->fire_immediately(); + } + }, this); + } + bool await_ready() const noexcept { return m_duration.count() <= 0; @@ -285,20 +379,41 @@ WINRT_EXPORT namespace winrt void await_suspend(std::experimental::coroutine_handle<> handle) { - m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, handle.address(), nullptr))); + m_handle = handle; + m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, this, nullptr))); int64_t relative_count = -m_duration.count(); - WINRT_IMPL_SetThreadpoolTimer(m_timer.get(), &relative_count, 0, 0); + WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &relative_count, 0, 0, nullptr); + + state expected = state::idle; + if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release)) + { + fire_immediately(); + } } - void await_resume() const noexcept + void await_resume() { + if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled) + { + throw hresult_canceled(); + } } private: + void fire_immediately() noexcept + { + if (WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), nullptr, 0, 0, nullptr)) + { + int64_t now = 0; + WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &now, 0, 0, nullptr); + } + } + static void __stdcall callback(void*, void* context, void*) noexcept { - std::experimental::coroutine_handle<>::from_address(context)(); + auto that = reinterpret_cast(context); + that->m_handle(); } struct timer_traits @@ -316,8 +431,12 @@ WINRT_EXPORT namespace winrt } }; + enum class state { idle, pending, canceled }; + handle_type m_timer; Windows::Foundation::TimeSpan m_duration; + std::experimental::coroutine_handle<> m_handle; + std::atomic m_state{ state::idle }; }; return awaitable{ duration }; @@ -332,13 +451,25 @@ WINRT_EXPORT namespace winrt [[nodiscard]] inline auto resume_on_signal(void* handle, Windows::Foundation::TimeSpan timeout = {}) noexcept { - struct awaitable + struct awaitable : enable_await_cancellation { awaitable(void* handle, Windows::Foundation::TimeSpan timeout) noexcept : m_timeout(timeout), m_handle(handle) {} + void enable_cancellation(cancellable_promise* promise) + { + promise->set_canceller([](void* context) + { + auto that = static_cast(context); + if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending) + { + that->fire_immediately(); + } + }, this); + } + bool await_ready() const noexcept { return WINRT_IMPL_WaitForSingleObject(m_handle, 0) == 0; @@ -350,16 +481,35 @@ WINRT_EXPORT namespace winrt m_wait.attach(check_pointer(WINRT_IMPL_CreateThreadpoolWait(callback, this, nullptr))); int64_t relative_count = -m_timeout.count(); int64_t* file_time = relative_count != 0 ? &relative_count : nullptr; - WINRT_IMPL_SetThreadpoolWait(m_wait.get(), m_handle, file_time); + WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), m_handle, file_time, nullptr); + + state expected = state::idle; + if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release)) + { + fire_immediately(); + } } - bool await_resume() const noexcept + bool await_resume() { + if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled) + { + throw hresult_canceled(); + } return m_result == 0; } private: + void fire_immediately() noexcept + { + if (WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), nullptr, nullptr, nullptr)) + { + int64_t now = 0; + WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), WINRT_IMPL_GetCurrentProcess(), &now, nullptr); + } + } + static void __stdcall callback(void*, void* context, void*, uint32_t result) noexcept { auto that = static_cast(context); @@ -382,11 +532,14 @@ WINRT_EXPORT namespace winrt } }; + enum class state { idle, pending, canceled }; + handle_type m_wait; Windows::Foundation::TimeSpan m_timeout; void* m_handle; uint32_t m_result{}; std::experimental::coroutine_handle<> m_resume{ nullptr }; + std::atomic m_state{ state::idle }; }; return awaitable{ handle, timeout }; diff --git a/strings/base_extern.h b/strings/base_extern.h index 59e89b610..859956f77 100644 --- a/strings/base_extern.h +++ b/strings/base_extern.h @@ -63,10 +63,10 @@ extern "C" int32_t __stdcall WINRT_IMPL_TrySubmitThreadpoolCallback(void(__stdcall *callback)(void*, void* context), void* context, void*) noexcept; winrt::impl::ptp_timer __stdcall WINRT_IMPL_CreateThreadpoolTimer(void(__stdcall *callback)(void*, void* context, void*), void* context, void*) noexcept; - void __stdcall WINRT_IMPL_SetThreadpoolTimer(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window) noexcept; + int32_t __stdcall WINRT_IMPL_SetThreadpoolTimerEx(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window, void* reserved) noexcept; void __stdcall WINRT_IMPL_CloseThreadpoolTimer(winrt::impl::ptp_timer timer) noexcept; winrt::impl::ptp_wait __stdcall WINRT_IMPL_CreateThreadpoolWait(void(__stdcall *callback)(void*, void* context, void*, uint32_t result), void* context, void*) noexcept; - void __stdcall WINRT_IMPL_SetThreadpoolWait(winrt::impl::ptp_wait wait, void* handle, void* timeout) noexcept; + int32_t __stdcall WINRT_IMPL_SetThreadpoolWaitEx(winrt::impl::ptp_wait wait, void* handle, void* timeout, void* reserved) noexcept; void __stdcall WINRT_IMPL_CloseThreadpoolWait(winrt::impl::ptp_wait wait) noexcept; winrt::impl::ptp_io __stdcall WINRT_IMPL_CreateThreadpoolIo(void* object, void(__stdcall *callback)(void*, void* context, void* overlapped, uint32_t result, std::size_t bytes, void*) noexcept, void* context, void*) noexcept; void __stdcall WINRT_IMPL_StartThreadpoolIo(winrt::impl::ptp_io io) noexcept; @@ -147,10 +147,10 @@ WINRT_IMPL_LINK(WaitForSingleObject, 8) WINRT_IMPL_LINK(TrySubmitThreadpoolCallback, 12) WINRT_IMPL_LINK(CreateThreadpoolTimer, 12) -WINRT_IMPL_LINK(SetThreadpoolTimer, 16) +WINRT_IMPL_LINK(SetThreadpoolTimerEx, 20) WINRT_IMPL_LINK(CloseThreadpoolTimer, 4) WINRT_IMPL_LINK(CreateThreadpoolWait, 12) -WINRT_IMPL_LINK(SetThreadpoolWait, 12) +WINRT_IMPL_LINK(SetThreadpoolWaitEx, 16) WINRT_IMPL_LINK(CloseThreadpoolWait, 4) WINRT_IMPL_LINK(CreateThreadpoolIo, 16) WINRT_IMPL_LINK(StartThreadpoolIo, 4) diff --git a/strings/base_includes.h b/strings/base_includes.h index 37c7e1baa..eb50d30a2 100644 --- a/strings/base_includes.h +++ b/strings/base_includes.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include diff --git a/test/test/async_propagate_cancel.cpp b/test/test/async_propagate_cancel.cpp new file mode 100644 index 000000000..f735f1bf2 --- /dev/null +++ b/test/test/async_propagate_cancel.cpp @@ -0,0 +1,139 @@ +#include "pch.h" + +using namespace winrt; +using namespace Windows::Foundation; + +namespace +{ + // + // Checks that cancellation propagation works. + // + + IAsyncAction Action() + { + // Do an extra co_await before the resume_on_signal + // so that there is a race window where we can try to cancel + // the "co_await resume_on_signal()" before it starts. + co_await resume_background(); + + auto cancel = co_await get_cancellation_token(); + cancel.enable_propagation(); + co_await resume_on_signal(GetCurrentProcess()); // never wakes + REQUIRE(false); + } + + IAsyncActionWithProgress ActionWithProgress() + { + // Do an extra co_await before the resume_on_signal + // so that there is a race window where we can try to cancel + // the "co_await resume_on_signal()" before it starts. + co_await resume_background(); + + auto cancel = co_await get_cancellation_token(); + cancel.enable_propagation(); + co_await resume_on_signal(GetCurrentProcess()); // never wakes + REQUIRE(false); + } + + IAsyncOperation Operation() + { + // Do an extra co_await before the resume_on_signal + // so that there is a race window where we can try to cancel + // the "co_await resume_on_signal()" before it starts. + co_await resume_background(); + + auto cancel = co_await get_cancellation_token(); + cancel.enable_propagation(); + co_await resume_on_signal(GetCurrentProcess()); // never wakes + REQUIRE(false); + co_return 1; + } + + IAsyncOperationWithProgress OperationWithProgress() + { + // Do an extra co_await before the resume_on_signal + // so that there is a race window where we can try to cancel + // the "co_await resume_on_signal()" before it starts. + co_await resume_background(); + + auto cancel = co_await get_cancellation_token(); + cancel.enable_propagation(); + co_await resume_on_signal(GetCurrentProcess()); // never wakes + REQUIRE(false); + co_return 1; + } + + // Checking cancellation propagation for resume_after. + IAsyncAction DelayAction() + { + // Do an extra co_await before the resume_on_signal + // so that there is a race window where we can try to cancel + // the "co_await resume_after()" before it starts. + co_await resume_background(); + + auto cancel = co_await get_cancellation_token(); + cancel.enable_propagation(); + co_await resume_after(std::chrono::hours(1)); // effectively sleep forever + REQUIRE(false); + } + + // Checking cancellation propagation for IAsyncAction. + // We nest "depth" layers deep and then cancel the very + // deeply-nested IAsyncAction. This validates that propagation + // carried all the way down, and also lets us verify (via + // manual debugging) the deep cancellation doesn't cause us to + // blow up the stack on deeply nested cancellation. + IAsyncAction ActionAction(int depth) + { + // Do an extra co_await before the resume_on_signal + // so that there is a race window where we can try to cancel + // the "co_await ActionAction()" before it starts. + co_await resume_background(); + + auto cancel = co_await get_cancellation_token(); + cancel.enable_propagation(); + if (depth > 0) + { + co_await ActionAction(depth - 1); + } + else + { + co_await Action(); + } + REQUIRE(false); + } + + template + void Check(F make) + { + handle completed{ CreateEvent(nullptr, true, false, nullptr) }; + auto async = make(); + REQUIRE(async.Status() == AsyncStatus::Started); + + async.Completed([&](auto&& sender, AsyncStatus status) + { + REQUIRE(async == sender); + REQUIRE(status == AsyncStatus::Canceled); + SetEvent(completed.get()); + }); + + async.Cancel(); + + // Wait indefinitely if a debugger is present, to make it easier to debug this test. + REQUIRE(WaitForSingleObject(completed.get(), IsDebuggerPresent() ? INFINITE : 1000) == WAIT_OBJECT_0); + + REQUIRE(async.Status() == AsyncStatus::Canceled); + REQUIRE(async.ErrorCode() == HRESULT_FROM_WIN32(ERROR_CANCELLED)); + REQUIRE_THROWS_AS(async.GetResults(), hresult_canceled); + } +} + +TEST_CASE("async_propagate_cancel") +{ + Check(Action); + Check(ActionWithProgress); + Check(Operation); + Check(OperationWithProgress); + Check(DelayAction); + Check([] { return ActionAction(10); }); +} diff --git a/test/test/test.vcxproj b/test/test/test.vcxproj index 3dd7004ce..8d5daeeb1 100644 --- a/test/test/test.vcxproj +++ b/test/test/test.vcxproj @@ -294,6 +294,7 @@ + From 64f3a57fc3b84819d04334362fb80d44d970610e Mon Sep 17 00:00:00 2001 From: Raymond Chen Date: Thu, 17 Sep 2020 14:25:02 -0700 Subject: [PATCH 2/2] Fix signature of SetThreadpoolTimerEx --- strings/base_coroutine_threadpool.h | 6 +++--- strings/base_extern.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index ac309b377..b630ef5c8 100644 --- a/strings/base_coroutine_threadpool.h +++ b/strings/base_coroutine_threadpool.h @@ -382,7 +382,7 @@ WINRT_EXPORT namespace winrt m_handle = handle; m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, this, nullptr))); int64_t relative_count = -m_duration.count(); - WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &relative_count, 0, 0, nullptr); + WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &relative_count, 0, 0); state expected = state::idle; if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release)) @@ -403,10 +403,10 @@ WINRT_EXPORT namespace winrt void fire_immediately() noexcept { - if (WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), nullptr, 0, 0, nullptr)) + if (WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), nullptr, 0, 0)) { int64_t now = 0; - WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &now, 0, 0, nullptr); + WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &now, 0, 0); } } diff --git a/strings/base_extern.h b/strings/base_extern.h index 859956f77..528e22c00 100644 --- a/strings/base_extern.h +++ b/strings/base_extern.h @@ -63,7 +63,7 @@ extern "C" int32_t __stdcall WINRT_IMPL_TrySubmitThreadpoolCallback(void(__stdcall *callback)(void*, void* context), void* context, void*) noexcept; winrt::impl::ptp_timer __stdcall WINRT_IMPL_CreateThreadpoolTimer(void(__stdcall *callback)(void*, void* context, void*), void* context, void*) noexcept; - int32_t __stdcall WINRT_IMPL_SetThreadpoolTimerEx(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window, void* reserved) noexcept; + int32_t __stdcall WINRT_IMPL_SetThreadpoolTimerEx(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window) noexcept; void __stdcall WINRT_IMPL_CloseThreadpoolTimer(winrt::impl::ptp_timer timer) noexcept; winrt::impl::ptp_wait __stdcall WINRT_IMPL_CreateThreadpoolWait(void(__stdcall *callback)(void*, void* context, void*, uint32_t result), void* context, void*) noexcept; int32_t __stdcall WINRT_IMPL_SetThreadpoolWaitEx(winrt::impl::ptp_wait wait, void* handle, void* timeout, void* reserved) noexcept; @@ -147,7 +147,7 @@ WINRT_IMPL_LINK(WaitForSingleObject, 8) WINRT_IMPL_LINK(TrySubmitThreadpoolCallback, 12) WINRT_IMPL_LINK(CreateThreadpoolTimer, 12) -WINRT_IMPL_LINK(SetThreadpoolTimerEx, 20) +WINRT_IMPL_LINK(SetThreadpoolTimerEx, 16) WINRT_IMPL_LINK(CloseThreadpoolTimer, 4) WINRT_IMPL_LINK(CreateThreadpoolWait, 12) WINRT_IMPL_LINK(SetThreadpoolWaitEx, 16)