diff --git a/strings/base_activation.h b/strings/base_activation.h index 7f35e1a79..685d9f5c0 100644 --- a/strings/base_activation.h +++ b/strings/base_activation.h @@ -487,6 +487,12 @@ WINRT_EXPORT namespace winrt impl::get_factory_cache().clear(); } + template + auto try_create_instance(guid const& clsid, uint32_t context = 0x1 /*CLSCTX_INPROC_SERVER*/, void* outer = nullptr) + { + return try_capture(WINRT_IMPL_CoCreateInstance, clsid, outer, context); + } + template auto create_instance(guid const& clsid, uint32_t context = 0x1 /*CLSCTX_INPROC_SERVER*/, void* outer = nullptr) { diff --git a/strings/base_com_ptr.h b/strings/base_com_ptr.h index 46867a73a..3ae9d88cd 100644 --- a/strings/base_com_ptr.h +++ b/strings/base_com_ptr.h @@ -153,6 +153,18 @@ WINRT_EXPORT namespace winrt *other = m_ptr; } + template + bool try_capture(F function, Args&&...args) + { + return function(args..., guid_of(), put_void()) >= 0; + } + + template + bool try_capture(com_ptr const& object, M method, Args&&...args) + { + return (object.get()->*(method))(args..., guid_of(), put_void()) >= 0; + } + template void capture(F function, Args&&...args) { @@ -204,6 +216,21 @@ WINRT_EXPORT namespace winrt type* m_ptr{}; }; + template + impl::com_ref try_capture(F function, Args&& ...args) + { + void* result{}; + function(args..., guid_of(), &result); + return { result, take_ownership_from_abi }; + } + + template + impl::com_ref try_capture(com_ptr const& object, M method, Args&& ...args) + { + void* result{}; + (object.get()->*(method))(args..., guid_of(), &result); + return { result, take_ownership_from_abi }; + } template impl::com_ref capture(F function, Args&& ...args) { diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index fe4975b96..311522a6c 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -35,7 +35,7 @@ namespace winrt::impl { // Note: A blocking wait on the UI thread for an asynchronous operation can cause a deadlock. // See https://docs.microsoft.com/windows/uwp/cpp-and-winrt-apis/concurrency#block-the-calling-thread - WINRT_ASSERT(!is_sta()); + WINRT_ASSERT(!is_sta_thread()); } template @@ -119,7 +119,7 @@ namespace winrt::impl private: std::experimental::coroutine_handle<> m_handle; - com_ptr m_context = apartment_context(); + resume_apartment_context m_context; void Complete() { diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index 254840dc2..ae26677ef 100644 --- a/strings/base_coroutine_threadpool.h +++ b/strings/base_coroutine_threadpool.h @@ -1,6 +1,14 @@ namespace winrt::impl { + inline auto submit_threadpool_callback(void(__stdcall* callback)(void*, void* context), void* context) + { + if (!WINRT_IMPL_TrySubmitThreadpoolCallback(callback, context, nullptr)) + { + throw_last_error(); + } + } + inline void __stdcall resume_background_callback(void*, void* context) noexcept { std::experimental::coroutine_handle<>::from_address(context)(); @@ -8,30 +16,43 @@ namespace winrt::impl inline auto resume_background(std::experimental::coroutine_handle<> handle) { - if (!WINRT_IMPL_TrySubmitThreadpoolCallback(resume_background_callback, handle.address(), nullptr)) - { - throw_last_error(); - } + submit_threadpool_callback(resume_background_callback, handle.address()); } - inline bool is_sta() noexcept + inline std::pair get_apartment_type() noexcept { int32_t aptType; int32_t aptTypeQualifier; - return (0 == WINRT_IMPL_CoGetApartmentType(&aptType, &aptTypeQualifier)) && ((aptType == 0 /*APTTYPE_STA*/) || (aptType == 3 /*APTTYPE_MAINSTA*/)); + if (0 == WINRT_IMPL_CoGetApartmentType(&aptType, &aptTypeQualifier)) + { + return { aptType, aptTypeQualifier }; + } + else + { + return { 1 /* APTTYPE_MTA */, 1 /* APTTYPEQUALIFIER_IMPLICIT_MTA */ }; + } } - inline bool requires_apartment_context() noexcept + inline bool is_sta_thread() noexcept { - int32_t aptType; - int32_t aptTypeQualifier; - return (0 == WINRT_IMPL_CoGetApartmentType(&aptType, &aptTypeQualifier)) && ((aptType == 0 /*APTTYPE_STA*/) || (aptType == 2 /*APTTYPE_NA*/) || (aptType == 3 /*APTTYPE_MAINSTA*/)); + auto type = get_apartment_type(); + switch (type.first) + { + case 0: /* APTTYPE_STA */ + case 3: /* APTTYPE_MAINSTA */ + return true; + case 2: /* APTTYPE_NA */ + return type.second == 3 /* APTTYPEQUALIFIER_NA_ON_STA */ || + type.second == 5 /* APTTYPEQUALIFIER_NA_ON_MAINSTA */; + } + return false; } - inline auto apartment_context() + struct resume_apartment_context { - return requires_apartment_context() ? capture(WINRT_IMPL_CoGetObjectContext) : nullptr; - } + com_ptr m_context = try_capture(WINRT_IMPL_CoGetObjectContext); + int32_t m_context_type = get_apartment_type().first; + }; inline int32_t __stdcall resume_apartment_callback(com_callback_args* args) noexcept { @@ -39,25 +60,49 @@ namespace winrt::impl return 0; }; - inline auto resume_apartment(com_ptr const& context, std::experimental::coroutine_handle<> handle) + inline void resume_apartment_sync(com_ptr const& context, std::experimental::coroutine_handle<> handle) + { + com_callback_args args{}; + args.data = handle.address(); + + check_hresult(context->ContextCallback(resume_apartment_callback, &args, guid_of(), 5, nullptr)); + } + + inline void resume_apartment_on_threadpool(com_ptr const& context, std::experimental::coroutine_handle<> handle) { - if (context) + struct threadpool_resume { - com_callback_args args{}; - args.data = handle.address(); + threadpool_resume(com_ptr const& context, std::experimental::coroutine_handle<> handle) : + m_context(context), m_handle(handle) { } + com_ptr m_context; + std::experimental::coroutine_handle<> m_handle; + }; + auto state = std::make_unique(context, handle); + submit_threadpool_callback([](void*, void* p) + { + std::unique_ptr state{ static_cast(p) }; + resume_apartment_sync(state->m_context, state->m_handle); + }, state.get()); + state.release(); + } - check_hresult(context->ContextCallback(resume_apartment_callback, &args, guid_of(), 5, nullptr)); + inline auto resume_apartment(resume_apartment_context const& context, std::experimental::coroutine_handle<> handle) + { + if ((context.m_context == nullptr) || (context.m_context == try_capture(WINRT_IMPL_CoGetObjectContext))) + { + handle(); + } + else if (context.m_context_type == 1 /* APTTYPE_MTA */) + { + resume_background(handle); + } + else if ((context.m_context_type == 2 /* APTTYPE_NTA */) && is_sta_thread()) + { + resume_apartment_on_threadpool(context.m_context, handle); } else { - if (requires_apartment_context()) - { - resume_background(handle); - } - else - { - handle(); - } + resume_apartment_sync(context.m_context, handle); } } @@ -294,7 +339,7 @@ WINRT_EXPORT namespace winrt impl::resume_apartment(context, handle); } - com_ptr context = impl::apartment_context(); + impl::resume_apartment_context context; }; [[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept diff --git a/test/old_tests/UnitTests/apartment_context.cpp b/test/old_tests/UnitTests/apartment_context.cpp index eab146973..ae3511698 100644 --- a/test/old_tests/UnitTests/apartment_context.cpp +++ b/test/old_tests/UnitTests/apartment_context.cpp @@ -1,8 +1,10 @@ #include "pch.h" #include "catch.hpp" +#include using namespace winrt; using namespace Windows::Foundation; +using namespace Windows::System; namespace { @@ -12,9 +14,56 @@ namespace co_await context; } + + template + void InvokeInContext(IContextCallback* context, TLambda&& lambda) + { + ComCallData data; + data.pUserDefined = λ + check_hresult(context->ContextCallback([](ComCallData* data) -> HRESULT + { + auto& lambda = *reinterpret_cast(data->pUserDefined); + lambda(); + return S_OK; + }, &data, IID_ICallbackWithNoReentrancyToApplicationSTA, 5, nullptr)); + } + + auto get_winrt_apartment_context_for_com_context(com_ptr<::IContextCallback> const& com_context) + { + std::optional context; + InvokeInContext(com_context.get(), [&] { + context = apartment_context(); + }); + return context.value(); + } + + bool is_nta_on_mta() + { + APTTYPE type; + APTTYPEQUALIFIER qualifier; + check_hresult(CoGetApartmentType(&type, &qualifier)); + return (type == APTTYPE_NA) && (qualifier == APTTYPEQUALIFIER_NA_ON_MTA || qualifier == APTTYPEQUALIFIER_NA_ON_IMPLICIT_MTA); + } + + IAsyncAction TestNeutralApartmentContext() + { + auto controller = DispatcherQueueController::CreateOnDedicatedThread(); + co_await resume_foreground(controller.DispatcherQueue()); + + // Entering neutral apartment from STA should resume on explicit background thread. + auto nta = get_winrt_apartment_context_for_com_context(capture<::IContextCallback>(CoGetDefaultContext, APTTYPE_NA)); + co_await nta; + + REQUIRE(is_nta_on_mta()); + } } TEST_CASE("apartment_context coverage") { Async().get(); } + +TEST_CASE("apartment_context nta") +{ + TestNeutralApartmentContext().get(); +} diff --git a/test/old_tests/UnitTests/capture.cpp b/test/old_tests/UnitTests/capture.cpp index e88a29056..d1c4e41de 100644 --- a/test/old_tests/UnitTests/capture.cpp +++ b/test/old_tests/UnitTests/capture.cpp @@ -58,3 +58,27 @@ TEST_CASE("capture") REQUIRE_THROWS_AS(capture(a, &ICapture::CreateMemberCapture, 0), hresult_no_interface); REQUIRE_THROWS_AS(d.capture(a, &ICapture::CreateMemberCapture, 0), hresult_no_interface); } + +TEST_CASE("try_capture") +{ + // Identical to the "capture" test above, just with different + // error handling. + com_ptr a = try_capture(CreateCapture, 10); + REQUIRE(a->GetValue() == 10); + a = nullptr; + REQUIRE(a.try_capture(CreateCapture, 20)); + REQUIRE(a->GetValue() == 20); + + auto b = try_capture(a, &ICapture::CreateMemberCapture, 30); + REQUIRE(b->GetValue() == 30); + b = nullptr; + REQUIRE(b.try_capture(a, &ICapture::CreateMemberCapture, 40)); + REQUIRE(b->GetValue() == 40); + + com_ptr d; + + REQUIRE(!try_capture(CreateCapture, 0)); + REQUIRE(!d.try_capture(CreateCapture, 0)); + REQUIRE(!try_capture(a, &ICapture::CreateMemberCapture, 0)); + REQUIRE(!d.try_capture(a, &ICapture::CreateMemberCapture, 0)); +} diff --git a/test/old_tests/UnitTests/create_instance.cpp b/test/old_tests/UnitTests/create_instance.cpp index 38f1bf031..fa1c2ee53 100644 --- a/test/old_tests/UnitTests/create_instance.cpp +++ b/test/old_tests/UnitTests/create_instance.cpp @@ -9,3 +9,12 @@ TEST_CASE("create_instance") com_ptr dialog = create_instance(guid_of()); REQUIRE(dialog); } + +TEST_CASE("try_create_instance") +{ + com_ptr dialog = try_create_instance(guid_of()); + REQUIRE(dialog); + + dialog = try_create_instance(CLSID_NULL); + REQUIRE(!dialog); +} diff --git a/test/test/await_adapter.cpp b/test/test/await_adapter.cpp index 65018d33b..16575699b 100644 --- a/test/test/await_adapter.cpp +++ b/test/test/await_adapter.cpp @@ -8,6 +8,14 @@ using namespace Windows::System; namespace { + bool is_sta() + { + APTTYPE type; + APTTYPEQUALIFIER qualifier; + check_hresult(CoGetApartmentType(&type, &qualifier)); + return (type == APTTYPE_STA) || (type == APTTYPE_MAINSTA); + } + static handle signal{ CreateEventW(nullptr, false, false, nullptr) }; IAsyncAction OtherForegroundAsync() @@ -29,9 +37,9 @@ namespace IAsyncAction ForegroundAsync(DispatcherQueue dispatcher) { - REQUIRE(!impl::is_sta()); + REQUIRE(!is_sta()); co_await resume_foreground(dispatcher); - REQUIRE(impl::is_sta()); + REQUIRE(is_sta()); // This exercises one STA thread waiting on another thus one context callback // completing on another. @@ -48,9 +56,9 @@ namespace fire_and_forget SignalFromForeground(DispatcherQueue dispatcher) { - REQUIRE(!impl::is_sta()); + REQUIRE(!is_sta()); co_await resume_foreground(dispatcher); - REQUIRE(impl::is_sta()); + REQUIRE(is_sta()); // Previously, this signal was never raised because the foreground thread // was always blocked waiting for ContextCallback to return. @@ -61,19 +69,19 @@ namespace { // Switch to a background (MTA) thread. co_await resume_background(); - REQUIRE(!impl::is_sta()); + REQUIRE(!is_sta()); // This exercises one MTA thread waiting on another and just completing // directly without the overhead of a context switch. co_await OtherBackgroundAsync(); - REQUIRE(!impl::is_sta()); + REQUIRE(!is_sta()); // Wait for a coroutine that completes on a foreground (STA) thread. co_await ForegroundAsync(dispatcher); // Resumption should automatically switch to a background (MTA) thread // without blocking the Completed handler (which would in turn block the foreground thread). - REQUIRE(!impl::is_sta()); + REQUIRE(!is_sta()); // Attempt to signal from the foreground thread under the assumption // that the foreground thread is not blocked.