From f5896c5711aa7831c227c3c74b6c25024efd0e64 Mon Sep 17 00:00:00 2001 From: Raymond Chen Date: Sun, 26 Sep 2021 20:11:51 -0700 Subject: [PATCH 1/3] Properly report failure to switch to target apartment If we are unable to switch to a target apartment (either by an explicit `co_await apartment_context` or implicitly at the completion of `co_await IAsyncInfo`), we now report the exception in the `co_await`'ing coroutine, so it can be caught and handled. Note that when this error occurs, we are stuck in limbo. We left the original apartment for the MTA, and now we can't get from the MTA to the target apartment. We have no choice but to raise the exception on the MTA, even though the coroutine may not have expected to be running there. ```cpp IAsyncAction Trapped() { Object o; // some object with a destructor co_await destroyed_ui_context; // throws an exception } ``` The destructor for `Object` will run on the MTA, even though the `Trapped` coroutine never runs on the MTA under normal conditions. This is making the best of a bad situation. We need to raise the exception into the coroutine so that the coroutine can enter a failure state and allow the coroutine chain to proceed. Since the object `o` already had to be prepared for destruction on the original apartment or the switched-to apartment (in case an exception occurs in either apartment), we know that the object cannot have apartment affinity anyway. So we're probably okay to destruct it on the MTA. The old code raised the exception on whatever thread happened to notice that something bad happened. Sometimes that's okay, for example if you do a `co_await destroyed_ui_context` from an MTA thread, in which case the exception is thrown into the coroutine and can be caught, or left uncaught to send the coroutine into an error state. However, all of the other cases weren't quite so happy. 
To avoid deadlocks, the apartment switch is routed through the threadpool, and that means that the failure is raised on a threadpool thread. There's nobody around to catch that exception, so the process fails fast. We solve the problem by having an atomic hresult in the awaiter, which is initialized to 0 (`S_OK`) and transitions to error when we are about to resume the coroutine on the wrong apartment, so that the error can be propagated during `await_resume()`, which happens in the context of the awaiting coroutine and therefore can be caught by the coroutine (or sent to `unhandled_exception`). The hresult in which we record the failure needs to be atomic to avoid data races in case multiple threads detect failure simultaneously. However, the access to it can be relaxed, because it is set and consumed by the same thread. The current implementation pulls a sneaky trick: To avoid having to create an `operator co_await` for the `apartment_context` object, we continue to let it be its own awaiter. This means that if the apartment switch fails, we record the failure in the `apartment_context` itself. We are making the assumption that once an apartment context goes bad, it is unrecoverably broken. 
--- strings/base_coroutine_foundation.h | 14 ++- strings/base_coroutine_threadpool.h | 48 ++++++--- .../old_tests/UnitTests/apartment_context.cpp | 97 +++++++++++++++++++ test/test/disconnected.cpp | 66 ++++++++++++- 4 files changed, 202 insertions(+), 23 deletions(-) diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index d7e325b9e..61b4f58fa 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -100,12 +100,13 @@ namespace winrt::impl struct disconnect_aware_handler { - disconnect_aware_handler(coroutine_handle<> handle) noexcept - : m_handle(handle) { } + disconnect_aware_handler(coroutine_handle<> handle, resumption_failure* failure) noexcept + : m_handle(handle), m_failure(failure) { } disconnect_aware_handler(disconnect_aware_handler&& other) noexcept : m_context(std::move(other.m_context)) - , m_handle(std::exchange(other.m_handle, {})) { } + , m_handle(std::exchange(other.m_handle, {})) + , m_failure(other.m_failure) { } ~disconnect_aware_handler() { @@ -120,10 +121,11 @@ namespace winrt::impl private: resume_apartment_context m_context; coroutine_handle<> m_handle; + resumption_failure* m_failure; void Complete() { - resume_apartment(m_context, std::exchange(m_handle, {})); + resume_apartment(m_context, std::exchange(m_handle, {}), m_failure); } }; @@ -135,6 +137,7 @@ namespace winrt::impl Async const& async; Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started; + resumption_failure failure; void enable_cancellation(cancellable_promise* promise) { @@ -152,7 +155,7 @@ namespace winrt::impl void await_suspend(coroutine_handle<> handle) { auto extend_lifetime = async; - async.Completed([this, handler = disconnect_aware_handler{ handle }](auto&&, auto operation_status) mutable + async.Completed([this, handler = disconnect_aware_handler{ handle, &failure }](auto&&, auto operation_status) mutable { status = operation_status; handler(); @@ -161,6 +164,7 @@ 
namespace winrt::impl auto await_resume() const { + failure.check(); check_status_canceled(status); return async.GetResults(); } diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index 4ea37bf75..3edf67d0c 100644 --- a/strings/base_coroutine_threadpool.h +++ b/strings/base_coroutine_threadpool.h @@ -77,36 +77,56 @@ namespace winrt::impl return 0; }; - inline void resume_apartment_sync(com_ptr const& context, coroutine_handle<> handle) + struct resumption_failure + { + resumption_failure() = default; + resumption_failure(resumption_failure const&) noexcept {} + resumption_failure& operator=(resumption_failure const&) noexcept { return *this; } + + // Call only on failure, never on success. + void report_failed(int32_t result) noexcept { failure.store(result, std::memory_order_relaxed); } + void check() const { check_hresult(failure.load(std::memory_order_relaxed)); } + + std::atomic failure; + }; + + inline void resume_apartment_sync(com_ptr const& context, coroutine_handle<> handle, resumption_failure* failure) { com_callback_args args{}; args.data = handle.address(); - check_hresult(context->ContextCallback(resume_apartment_callback, &args, guid_of(), 5, nullptr)); + auto result = context->ContextCallback(resume_apartment_callback, &args, guid_of(), 5, nullptr); + if (result < 0) + { + // Resume the coroutine on the wrong apartment, but tell it why. 
+ failure->report_failed(result); + handle(); + } } struct threadpool_resume { - threadpool_resume(com_ptr const& context, coroutine_handle<> handle) : - m_context(context), m_handle(handle) { } + threadpool_resume(com_ptr const& context, coroutine_handle<> handle, resumption_failure* failure) : + m_context(context), m_handle(handle), m_failure(failure) { } com_ptr m_context; coroutine_handle<> m_handle; + resumption_failure* m_failure; }; inline void __stdcall fallback_submit_threadpool_callback(void*, void* p) noexcept { std::unique_ptr state{ static_cast(p) }; - resume_apartment_sync(state->m_context, state->m_handle); + resume_apartment_sync(state->m_context, state->m_handle, state->m_failure); } - inline void resume_apartment_on_threadpool(com_ptr const& context, coroutine_handle<> handle) + inline void resume_apartment_on_threadpool(com_ptr const& context, coroutine_handle<> handle, resumption_failure* failure) { - auto state = std::make_unique(context, handle); + auto state = std::make_unique(context, handle, failure); submit_threadpool_callback(fallback_submit_threadpool_callback, state.get()); state.release(); } - inline auto resume_apartment(resume_apartment_context const& context, coroutine_handle<> handle) + inline auto resume_apartment(resume_apartment_context const& context, coroutine_handle<> handle, resumption_failure* failure) { WINRT_ASSERT(context.valid()); if ((context.m_context == nullptr) || (context.m_context == try_capture(WINRT_IMPL_CoGetObjectContext))) @@ -119,11 +139,11 @@ namespace winrt::impl } else if (is_sta_thread()) { - resume_apartment_on_threadpool(context.m_context, handle); + resume_apartment_on_threadpool(context.m_context, handle, failure); } else { - resume_apartment_sync(context.m_context, handle); + resume_apartment_sync(context.m_context, handle, failure); } } @@ -370,17 +390,19 @@ WINRT_EXPORT namespace winrt return false; } - void await_resume() const noexcept + void await_resume() const { + failure.check(); } - void 
await_suspend(impl::coroutine_handle<> handle) const + void await_suspend(impl::coroutine_handle<> handle) { auto copy = context; // resuming may destruct *this, so use a copy - impl::resume_apartment(copy, handle); + impl::resume_apartment(copy, handle, &failure); } impl::resume_apartment_context context; + impl::resumption_failure failure; // assumes that once a context fails, it always fails }; [[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept diff --git a/test/old_tests/UnitTests/apartment_context.cpp b/test/old_tests/UnitTests/apartment_context.cpp index a7b34414e..04e576333 100644 --- a/test/old_tests/UnitTests/apartment_context.cpp +++ b/test/old_tests/UnitTests/apartment_context.cpp @@ -77,6 +77,19 @@ namespace return (type == APTTYPE_NA) && (qualifier == APTTYPEQUALIFIER_NA_ON_MTA || qualifier == APTTYPEQUALIFIER_NA_ON_IMPLICIT_MTA); } + bool is_mta() + { + APTTYPE type; + APTTYPEQUALIFIER qualifier; + check_hresult(CoGetApartmentType(&type, &qualifier)); + return type == APTTYPE_MTA; + } + + bool is_invalid_context_error(HRESULT hr) + { + return (hr == RPC_E_SERVER_DIED_DNE) || (hr == RPC_E_DISCONNECTED); + } + IAsyncAction TestNeutralApartmentContext() { auto controller = DispatcherQueueController::CreateOnDedicatedThread(); @@ -144,10 +157,89 @@ namespace co_await controller1.ShutdownQueueAsync(); co_await controller2.ShutdownQueueAsync(); +} + +IAsyncAction TestDisconnectedApartmentContext() +{ + // Create an STA thread and switch to it. + auto controller1 = DispatcherQueueController::CreateOnDedicatedThread(); + co_await resume_foreground(controller1.DispatcherQueue()); + + // Create another STA thread. + auto controller2 = DispatcherQueueController::CreateOnDedicatedThread(); + + // Test returning to a destroyed context after IAsyncXxx completion on STA. 
+ HRESULT error = S_OK; + try + { + co_await[](auto controller1, auto controller2) -> IAsyncAction + { + co_await resume_foreground(controller2.DispatcherQueue()); + co_await controller1.ShutdownQueueAsync(); + }(controller1, controller2); } + catch (...) + { + error = winrt::to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(is_invalid_context_error(error)); + + // Re-create the STA thread for the next test. + controller1 = DispatcherQueueController::CreateOnDedicatedThread(); + co_await resume_foreground(controller1.DispatcherQueue()); + // Save the COM context for the first STA thread. + apartment_context context1; + + // Test returning to a destroyed context after IAsyncXxx completion on MTA. + error = S_OK; + try + { + co_await[](auto controller1) -> IAsyncAction + { + co_await resume_background(); // doesn't work + co_await controller1.ShutdownQueueAsync(); + }(controller1); + } + catch (...) + { + error = winrt::to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(is_invalid_context_error(error)); + + // Test switching to a destroyed context from an STA. + co_await resume_foreground(controller2.DispatcherQueue()); + + error = S_OK; + try { + co_await context1; + } + catch (...) + { + error = winrt::to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(is_invalid_context_error(error)); + + // Test switching to a destroyed context from an MTA. + error = S_OK; + try { + co_await context1; + } + catch (...) + { + error = winrt::to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(is_invalid_context_error(error)); + + // Clean up. 
+ co_await controller2.ShutdownQueueAsync(); } +} TEST_CASE("apartment_context coverage") { Async().get(); @@ -162,3 +254,8 @@ TEST_CASE("apartment_context sta") { TestStaToStaApartmentContext().get(); } + +TEST_CASE("apartment_context disconnected") +{ + TestDisconnectedApartmentContext().get(); +} diff --git a/test/test/disconnected.cpp b/test/test/disconnected.cpp index d02976dc4..41733e2f8 100644 --- a/test/test/disconnected.cpp +++ b/test/test/disconnected.cpp @@ -1,9 +1,11 @@ #include "pch.h" #include +#include using namespace std::literals; using namespace winrt; using namespace Windows::Foundation; +using namespace Windows::System; namespace { @@ -32,6 +34,14 @@ namespace progress(123); co_return 123; } + + bool is_mta() + { + APTTYPE type; + APTTYPEQUALIFIER qualifier; + check_hresult(CoGetApartmentType(&type, &qualifier)); + return type == APTTYPE_MTA; + } } TEST_CASE("disconnected,handler") @@ -127,7 +137,7 @@ TEST_CASE("disconnected,handler") // Custom action to simulate an out-of-process server that crashes before it can complete. 
struct non_agile_abandoned_action : implements { - non_agile_abandoned_action(void* event_handle) : m_awaited(event_handle) {} + non_agile_abandoned_action(delegate<> disconnect) : m_disconnect(disconnect) {} static fire_and_forget final_release(std::unique_ptr self) { @@ -140,8 +150,7 @@ struct non_agile_abandoned_action : implements m_disconnect; }; namespace @@ -208,7 +217,7 @@ TEST_CASE("disconnected,action") agile_ref action; InvokeInContext(private_context.get(), [&]() { - action = make(signal.get()); + action = make([&]{ SetEvent(signal.get()); }); }); auto result = [](IAsyncAction action) -> IAsyncAction @@ -218,3 +227,50 @@ TEST_CASE("disconnected,action") REQUIRE_THROWS_MATCHES(result.get(), hresult_error, holds_hresult(RPC_E_DISCONNECTED)); } + +TEST_CASE("disconnected,double") +{ + // The double-disconnect case, where the IAsyncAction disconnects, + // and tries to return to the original context, but it too has disconnected! + auto test = []() -> IAsyncAction + { + auto private_context = create_instance(CLSID_ContextSwitcher); + handle signal{ CreateEventW(nullptr, true, false, nullptr) }; + disconnect_on_signal(private_context, signal.get()); + + // Create an STA thread that we will destroy while awaiting. + auto controller = DispatcherQueueController::CreateOnDedicatedThread(); + + agile_ref action; + InvokeInContext(private_context.get(), [&]() + { + action = make([&]() -> fire_and_forget + { + // Get off the DispatcherQueue thread. + co_await resume_background(); + // Destroy the DispatcherQueue, so the co_await has nowhere to return to. + co_await controller.ShutdownQueueAsync(); + // Now set the event to force the action to disconnect. + SetEvent(signal.get()); + }); + }); + + // Go to our STA thread. + co_await resume_foreground(controller.DispatcherQueue()); + + HRESULT hr = S_OK; + try + { + co_await action.get(); + } + catch (...) 
+ { + hr = to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(FAILED(hr)); + + }(); + + test.get(); +} From 7d25848b71c6b9e9fc572ac5e0ebe6d5ed09960b Mon Sep 17 00:00:00 2001 From: Raymond Chen Date: Mon, 27 Sep 2021 06:45:55 -0700 Subject: [PATCH 2/3] Allow await a const apartment_context This also removes the need for an atomic `int32_t`, and removes the assumption that context failure is permanent. --- strings/base_coroutine_foundation.h | 8 +-- strings/base_coroutine_threadpool.h | 54 ++++++++++--------- .../old_tests/UnitTests/apartment_context.cpp | 2 +- 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index 61b4f58fa..717d09033 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -100,7 +100,7 @@ namespace winrt::impl struct disconnect_aware_handler { - disconnect_aware_handler(coroutine_handle<> handle, resumption_failure* failure) noexcept + disconnect_aware_handler(coroutine_handle<> handle, int32_t* failure) noexcept : m_handle(handle), m_failure(failure) { } disconnect_aware_handler(disconnect_aware_handler&& other) noexcept @@ -121,7 +121,7 @@ namespace winrt::impl private: resume_apartment_context m_context; coroutine_handle<> m_handle; - resumption_failure* m_failure; + int32_t* m_failure = 0; void Complete() { @@ -137,7 +137,7 @@ namespace winrt::impl Async const& async; Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started; - resumption_failure failure; + int32_t failure; void enable_cancellation(cancellable_promise* promise) { @@ -164,7 +164,7 @@ namespace winrt::impl auto await_resume() const { - failure.check(); + check_hresult(failure); check_status_canceled(status); return async.GetResults(); } diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index 3edf67d0c..ae77512d7 100644 --- a/strings/base_coroutine_threadpool.h +++ 
b/strings/base_coroutine_threadpool.h @@ -77,20 +77,7 @@ namespace winrt::impl return 0; }; - struct resumption_failure - { - resumption_failure() = default; - resumption_failure(resumption_failure const&) noexcept {} - resumption_failure& operator=(resumption_failure const&) noexcept { return *this; } - - // Call only on failure, never on success. - void report_failed(int32_t result) noexcept { failure.store(result, std::memory_order_relaxed); } - void check() const { check_hresult(failure.load(std::memory_order_relaxed)); } - - std::atomic failure; - }; - - inline void resume_apartment_sync(com_ptr const& context, coroutine_handle<> handle, resumption_failure* failure) + inline void resume_apartment_sync(com_ptr const& context, coroutine_handle<> handle, int32_t* failure) { com_callback_args args{}; args.data = handle.address(); @@ -99,18 +86,18 @@ namespace winrt::impl if (result < 0) { // Resume the coroutine on the wrong apartment, but tell it why. - failure->report_failed(result); + *failure = result; handle(); } } struct threadpool_resume { - threadpool_resume(com_ptr const& context, coroutine_handle<> handle, resumption_failure* failure) : + threadpool_resume(com_ptr const& context, coroutine_handle<> handle, int32_t* failure) : m_context(context), m_handle(handle), m_failure(failure) { } com_ptr m_context; coroutine_handle<> m_handle; - resumption_failure* m_failure; + int32_t* m_failure; }; inline void __stdcall fallback_submit_threadpool_callback(void*, void* p) noexcept @@ -119,14 +106,14 @@ namespace winrt::impl resume_apartment_sync(state->m_context, state->m_handle, state->m_failure); } - inline void resume_apartment_on_threadpool(com_ptr const& context, coroutine_handle<> handle, resumption_failure* failure) + inline void resume_apartment_on_threadpool(com_ptr const& context, coroutine_handle<> handle, int32_t* failure) { auto state = std::make_unique(context, handle, failure); submit_threadpool_callback(fallback_submit_threadpool_callback, 
state.get()); state.release(); } - inline auto resume_apartment(resume_apartment_context const& context, coroutine_handle<> handle, resumption_failure* failure) + inline auto resume_apartment(resume_apartment_context const& context, coroutine_handle<> handle, int32_t* failure) { WINRT_ASSERT(context.valid()); if ((context.m_context == nullptr) || (context.m_context == try_capture(WINRT_IMPL_CoGetObjectContext))) @@ -385,6 +372,17 @@ WINRT_EXPORT namespace winrt operator bool() const noexcept { return context.valid(); } bool operator!() const noexcept { return !context.valid(); } + impl::resume_apartment_context context; + }; +} + +namespace winrt::impl +{ + struct apartment_awaiter + { + apartment_context context; // make a copy because resuming may destruct the original + int32_t failure = 0; + bool await_ready() const noexcept { return false; @@ -392,18 +390,24 @@ WINRT_EXPORT namespace winrt void await_resume() const { - failure.check(); + check_hresult(failure); } void await_suspend(impl::coroutine_handle<> handle) { - auto copy = context; // resuming may destruct *this, so use a copy - impl::resume_apartment(copy, handle, &failure); + impl::resume_apartment(context.context, handle, &failure); } - - impl::resume_apartment_context context; - impl::resumption_failure failure; // assumes that once a context fails, it always fails }; +} + +WINRT_EXPORT namespace winrt +{ +#ifdef WINRT_IMPL_COROUTINES + inline impl::apartment_awaiter operator co_await(apartment_context const& context) + { + return{ context }; + } +#endif [[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept { diff --git a/test/old_tests/UnitTests/apartment_context.cpp b/test/old_tests/UnitTests/apartment_context.cpp index 04e576333..e7c152fd2 100644 --- a/test/old_tests/UnitTests/apartment_context.cpp +++ b/test/old_tests/UnitTests/apartment_context.cpp @@ -106,7 +106,7 @@ namespace { bool pass = false; - apartment_context original; + const apartment_context 
original; // Create an STA thread and switch to it. auto controller1 = DispatcherQueueController::CreateOnDedicatedThread(); From 0bb17817a9556b76f13010c3264444dbb602a55f Mon Sep 17 00:00:00 2001 From: Raymond Chen Date: Mon, 27 Sep 2021 07:24:19 -0700 Subject: [PATCH 3/3] Initialized the wrong member variable Should be initializing the int32_t, not the pointer --- strings/base_coroutine_foundation.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index 717d09033..8e4993a40 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -121,7 +121,7 @@ namespace winrt::impl private: resume_apartment_context m_context; coroutine_handle<> m_handle; - int32_t* m_failure = 0; + int32_t* m_failure; void Complete() { @@ -137,7 +137,7 @@ namespace winrt::impl Async const& async; Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started; - int32_t failure; + int32_t failure = 0; void enable_cancellation(cancellable_promise* promise) {