diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index d7e325b9e..8e4993a40 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -100,12 +100,13 @@ namespace winrt::impl struct disconnect_aware_handler { - disconnect_aware_handler(coroutine_handle<> handle) noexcept - : m_handle(handle) { } + disconnect_aware_handler(coroutine_handle<> handle, int32_t* failure) noexcept + : m_handle(handle), m_failure(failure) { } disconnect_aware_handler(disconnect_aware_handler&& other) noexcept : m_context(std::move(other.m_context)) - , m_handle(std::exchange(other.m_handle, {})) { } + , m_handle(std::exchange(other.m_handle, {})) + , m_failure(other.m_failure) { } ~disconnect_aware_handler() { @@ -120,10 +121,11 @@ namespace winrt::impl private: resume_apartment_context m_context; coroutine_handle<> m_handle; + int32_t* m_failure; void Complete() { - resume_apartment(m_context, std::exchange(m_handle, {})); + resume_apartment(m_context, std::exchange(m_handle, {}), m_failure); } }; @@ -135,6 +137,7 @@ namespace winrt::impl Async const& async; Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started; + int32_t failure = 0; void enable_cancellation(cancellable_promise* promise) { @@ -152,7 +155,7 @@ namespace winrt::impl void await_suspend(coroutine_handle<> handle) { auto extend_lifetime = async; - async.Completed([this, handler = disconnect_aware_handler{ handle }](auto&&, auto operation_status) mutable + async.Completed([this, handler = disconnect_aware_handler{ handle, &failure }](auto&&, auto operation_status) mutable { status = operation_status; handler(); @@ -161,6 +164,7 @@ namespace winrt::impl auto await_resume() const { + check_hresult(failure); check_status_canceled(status); return async.GetResults(); } diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index 4ea37bf75..ae77512d7 100644 --- a/strings/base_coroutine_threadpool.h +++ b/strings/base_coroutine_threadpool.h @@ -77,36 +77,43 @@ namespace winrt::impl return 0; }; - inline void resume_apartment_sync(com_ptr const& context, coroutine_handle<> handle) + inline void resume_apartment_sync(com_ptr const& context, coroutine_handle<> handle, int32_t* failure) { com_callback_args args{}; args.data = handle.address(); - check_hresult(context->ContextCallback(resume_apartment_callback, &args, guid_of(), 5, nullptr)); + auto result = context->ContextCallback(resume_apartment_callback, &args, guid_of(), 5, nullptr); + if (result < 0) + { + // Resume the coroutine on the wrong apartment, but tell it why. + *failure = result; + handle(); + } } struct threadpool_resume { - threadpool_resume(com_ptr const& context, coroutine_handle<> handle) : - m_context(context), m_handle(handle) { } + threadpool_resume(com_ptr const& context, coroutine_handle<> handle, int32_t* failure) : + m_context(context), m_handle(handle), m_failure(failure) { } com_ptr m_context; coroutine_handle<> m_handle; + int32_t* m_failure; }; inline void __stdcall fallback_submit_threadpool_callback(void*, void* p) noexcept { std::unique_ptr state{ static_cast(p) }; - resume_apartment_sync(state->m_context, state->m_handle); + resume_apartment_sync(state->m_context, state->m_handle, state->m_failure); } - inline void resume_apartment_on_threadpool(com_ptr const& context, coroutine_handle<> handle) + inline void resume_apartment_on_threadpool(com_ptr const& context, coroutine_handle<> handle, int32_t* failure) { - auto state = std::make_unique(context, handle); + auto state = std::make_unique(context, handle, failure); submit_threadpool_callback(fallback_submit_threadpool_callback, state.get()); state.release(); } - inline auto resume_apartment(resume_apartment_context const& context, coroutine_handle<> handle) + inline auto resume_apartment(resume_apartment_context const& context, coroutine_handle<> handle, int32_t* failure) { WINRT_ASSERT(context.valid()); if ((context.m_context == nullptr) || (context.m_context == try_capture(WINRT_IMPL_CoGetObjectContext))) @@ -119,11 +126,11 @@ namespace winrt::impl } else if (is_sta_thread()) { - resume_apartment_on_threadpool(context.m_context, handle); + resume_apartment_on_threadpool(context.m_context, handle, failure); } else { - resume_apartment_sync(context.m_context, handle); + resume_apartment_sync(context.m_context, handle, failure); } } @@ -365,23 +372,42 @@ WINRT_EXPORT namespace winrt operator bool() const noexcept { return context.valid(); } bool operator!() const noexcept { return !context.valid(); } + impl::resume_apartment_context context; + }; +} + +namespace winrt::impl +{ + struct apartment_awaiter + { + apartment_context context; // make a copy because resuming may destruct the original + int32_t failure = 0; + bool await_ready() const noexcept { return false; } - void await_resume() const noexcept + void await_resume() const { + check_hresult(failure); } - void await_suspend(impl::coroutine_handle<> handle) const + void await_suspend(impl::coroutine_handle<> handle) { - auto copy = context; // resuming may destruct *this, so use a copy - impl::resume_apartment(copy, handle); + impl::resume_apartment(context.context, handle, &failure); } - - impl::resume_apartment_context context; }; +} + +WINRT_EXPORT namespace winrt +{ +#ifdef WINRT_IMPL_COROUTINES + inline impl::apartment_awaiter operator co_await(apartment_context const& context) + { + return{ context }; + } +#endif [[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept { diff --git a/test/old_tests/UnitTests/apartment_context.cpp b/test/old_tests/UnitTests/apartment_context.cpp index a7b34414e..e7c152fd2 100644 --- a/test/old_tests/UnitTests/apartment_context.cpp +++ b/test/old_tests/UnitTests/apartment_context.cpp @@ -77,6 +77,19 @@ namespace return (type == APTTYPE_NA) && (qualifier == APTTYPEQUALIFIER_NA_ON_MTA || qualifier == APTTYPEQUALIFIER_NA_ON_IMPLICIT_MTA); } + bool is_mta() + { + APTTYPE type; + APTTYPEQUALIFIER qualifier; + check_hresult(CoGetApartmentType(&type, &qualifier)); + return type == APTTYPE_MTA; + } + + bool is_invalid_context_error(HRESULT hr) + { + return (hr == RPC_E_SERVER_DIED_DNE) || (hr == RPC_E_DISCONNECTED); + } + IAsyncAction TestNeutralApartmentContext() { auto controller = DispatcherQueueController::CreateOnDedicatedThread(); @@ -93,7 +106,7 @@ namespace { bool pass = false; - apartment_context original; + const apartment_context original; // Create an STA thread and switch to it. auto controller1 = DispatcherQueueController::CreateOnDedicatedThread(); @@ -144,10 +157,89 @@ namespace co_await controller1.ShutdownQueueAsync(); co_await controller2.ShutdownQueueAsync(); +} + +IAsyncAction TestDisconnectedApartmentContext() +{ + // Create an STA thread and switch to it. + auto controller1 = DispatcherQueueController::CreateOnDedicatedThread(); + co_await resume_foreground(controller1.DispatcherQueue()); + + // Create another STA thread. + auto controller2 = DispatcherQueueController::CreateOnDedicatedThread(); + + // Test returning to a destroyed context after IAsyncXxx completion on STA. + HRESULT error = S_OK; + try + { + co_await[](auto controller1, auto controller2) -> IAsyncAction + { + co_await resume_foreground(controller2.DispatcherQueue()); + co_await controller1.ShutdownQueueAsync(); + }(controller1, controller2); } + catch (...) + { + error = winrt::to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(is_invalid_context_error(error)); + + // Re-create the STA thread for the next test. + controller1 = DispatcherQueueController::CreateOnDedicatedThread(); + co_await resume_foreground(controller1.DispatcherQueue()); + // Save the COM context for the first STA thread. + apartment_context context1; + + // Test returning to a destroyed context after IAsyncXxx completion on MTA. + error = S_OK; + try + { + co_await[](auto controller1) -> IAsyncAction + { + co_await resume_background(); // doesn't work + co_await controller1.ShutdownQueueAsync(); + }(controller1); + } + catch (...) + { + error = winrt::to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(is_invalid_context_error(error)); + + // Test switching to a destroyed context from an STA. + co_await resume_foreground(controller2.DispatcherQueue()); + + error = S_OK; + try { + co_await context1; + } + catch (...) + { + error = winrt::to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(is_invalid_context_error(error)); + + // Test switching to a destroyed context from an MTA. + error = S_OK; + try { + co_await context1; + } + catch (...) + { + error = winrt::to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(is_invalid_context_error(error)); + + // Clean up. + co_await controller2.ShutdownQueueAsync(); } +} TEST_CASE("apartment_context coverage") { Async().get(); @@ -162,3 +254,8 @@ TEST_CASE("apartment_context sta") { TestStaToStaApartmentContext().get(); } + +TEST_CASE("apartment_context disconnected") +{ + TestDisconnectedApartmentContext().get(); +} diff --git a/test/test/disconnected.cpp b/test/test/disconnected.cpp index d02976dc4..41733e2f8 100644 --- a/test/test/disconnected.cpp +++ b/test/test/disconnected.cpp @@ -1,9 +1,11 @@ #include "pch.h" #include +#include using namespace std::literals; using namespace winrt; using namespace Windows::Foundation; +using namespace Windows::System; namespace { @@ -32,6 +34,14 @@ namespace progress(123); co_return 123; } + + bool is_mta() + { + APTTYPE type; + APTTYPEQUALIFIER qualifier; + check_hresult(CoGetApartmentType(&type, &qualifier)); + return type == APTTYPE_MTA; + } } TEST_CASE("disconnected,handler") @@ -127,7 +137,7 @@ TEST_CASE("disconnected,handler") // Custom action to simulate an out-of-process server that crashes before it can complete. struct non_agile_abandoned_action : implements { - non_agile_abandoned_action(void* event_handle) : m_awaited(event_handle) {} + non_agile_abandoned_action(delegate<> disconnect) : m_disconnect(disconnect) {} static fire_and_forget final_release(std::unique_ptr self) { @@ -140,8 +150,7 @@ struct non_agile_abandoned_action : implements m_disconnect; }; namespace @@ -208,7 +217,7 @@ TEST_CASE("disconnected,action") agile_ref action; InvokeInContext(private_context.get(), [&]() { - action = make(signal.get()); + action = make([&]{ SetEvent(signal.get()); }); }); auto result = [](IAsyncAction action) -> IAsyncAction @@ -218,3 +227,50 @@ TEST_CASE("disconnected,action") REQUIRE_THROWS_MATCHES(result.get(), hresult_error, holds_hresult(RPC_E_DISCONNECTED)); } + +TEST_CASE("disconnected,double") +{ + // The double-disconnect case, where the IAsyncAction disconnects, + // and tries to return to the original context, but it too has disconnected! + auto test = []() -> IAsyncAction + { + auto private_context = create_instance(CLSID_ContextSwitcher); + handle signal{ CreateEventW(nullptr, true, false, nullptr) }; + disconnect_on_signal(private_context, signal.get()); + + // Create an STA thread that we will destroy while awaiting. + auto controller = DispatcherQueueController::CreateOnDedicatedThread(); + + agile_ref action; + InvokeInContext(private_context.get(), [&]() + { + action = make([&]() -> fire_and_forget + { + // Get off the DispatcherQueue thread. + co_await resume_background(); + // Destroy the DispatcherQueue, so the co_await has nowhere to return to. + co_await controller.ShutdownQueueAsync(); + // Now set the event to force the action to disconnect. + SetEvent(signal.get()); + }); + }); + + // Go to our STA thread. + co_await resume_foreground(controller.DispatcherQueue()); + + HRESULT hr = S_OK; + try + { + co_await action.get(); + } + catch (...) + { + hr = to_hresult(); + REQUIRE(is_mta()); + } + REQUIRE(FAILED(hr)); + + }(); + + test.get(); +}