Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions strings/base_activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ WINRT_EXPORT namespace winrt
impl::get_factory_cache().clear();
}

template <typename Interface>
auto try_create_instance(guid const& clsid, uint32_t context = 0x1 /*CLSCTX_INPROC_SERVER*/, void* outer = nullptr)
{
return try_capture<Interface>(WINRT_IMPL_CoCreateInstance, clsid, outer, context);
}

template <typename Interface>
auto create_instance(guid const& clsid, uint32_t context = 0x1 /*CLSCTX_INPROC_SERVER*/, void* outer = nullptr)
{
Expand Down
27 changes: 27 additions & 0 deletions strings/base_com_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,18 @@ WINRT_EXPORT namespace winrt
*other = m_ptr;
}

template <typename F, typename...Args>
bool try_capture(F function, Args&&...args)
{
return function(args..., guid_of<T>(), put_void()) >= 0;
}

template <typename O, typename M, typename...Args>
bool try_capture(com_ptr<O> const& object, M method, Args&&...args)
{
return (object.get()->*(method))(args..., guid_of<T>(), put_void()) >= 0;
}

template <typename F, typename...Args>
void capture(F function, Args&&...args)
{
Expand Down Expand Up @@ -204,6 +216,21 @@ WINRT_EXPORT namespace winrt
type* m_ptr{};
};

template <typename T, typename F, typename...Args>
impl::com_ref<T> try_capture(F function, Args&& ...args)
{
void* result{};
function(args..., guid_of<T>(), &result);
return { result, take_ownership_from_abi };
}

template <typename T, typename O, typename M, typename...Args>
impl::com_ref<T> try_capture(com_ptr<O> const& object, M method, Args&& ...args)
{
void* result{};
(object.get()->*(method))(args..., guid_of<T>(), &result);
return { result, take_ownership_from_abi };
}
template <typename T, typename F, typename...Args>
impl::com_ref<T> capture(F function, Args&& ...args)
{
Expand Down
4 changes: 2 additions & 2 deletions strings/base_coroutine_foundation.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace winrt::impl
{
// Note: A blocking wait on the UI thread for an asynchronous operation can cause a deadlock.
// See https://docs.microsoft.com/windows/uwp/cpp-and-winrt-apis/concurrency#block-the-calling-thread
WINRT_ASSERT(!is_sta());
WINRT_ASSERT(!is_sta_thread());
}

template <typename T, typename H>
Expand Down Expand Up @@ -119,7 +119,7 @@ namespace winrt::impl

private:
std::experimental::coroutine_handle<> m_handle;
com_ptr<IContextCallback> m_context = apartment_context();
resume_apartment_context m_context;

void Complete()
{
Expand Down
99 changes: 72 additions & 27 deletions strings/base_coroutine_threadpool.h
Original file line number Diff line number Diff line change
@@ -1,63 +1,108 @@

namespace winrt::impl
{
inline auto submit_threadpool_callback(void(__stdcall* callback)(void*, void* context), void* context)
{
if (!WINRT_IMPL_TrySubmitThreadpoolCallback(callback, context, nullptr))
{
throw_last_error();
}
}

inline void __stdcall resume_background_callback(void*, void* context) noexcept
{
std::experimental::coroutine_handle<>::from_address(context)();
};

inline auto resume_background(std::experimental::coroutine_handle<> handle)
{
if (!WINRT_IMPL_TrySubmitThreadpoolCallback(resume_background_callback, handle.address(), nullptr))
{
throw_last_error();
}
submit_threadpool_callback(resume_background_callback, handle.address());
}

inline bool is_sta() noexcept
inline std::pair<int32_t, int32_t> get_apartment_type() noexcept
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider mentioning the results hold APTTYPE_* and APTTYPEQUALIFIER_* values

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already mentioned in the PR description.

{
int32_t aptType;
int32_t aptTypeQualifier;
return (0 == WINRT_IMPL_CoGetApartmentType(&aptType, &aptTypeQualifier)) && ((aptType == 0 /*APTTYPE_STA*/) || (aptType == 3 /*APTTYPE_MAINSTA*/));
if (0 == WINRT_IMPL_CoGetApartmentType(&aptType, &aptTypeQualifier))
{
return { aptType, aptTypeQualifier };
}
else
{
return { 1 /* APTTYPE_MTA */, 1 /* APTTYPEQUALIFIER_IMPLICIT_MTA */ };
}
}

inline bool requires_apartment_context() noexcept
inline bool is_sta_thread() noexcept
{
int32_t aptType;
int32_t aptTypeQualifier;
return (0 == WINRT_IMPL_CoGetApartmentType(&aptType, &aptTypeQualifier)) && ((aptType == 0 /*APTTYPE_STA*/) || (aptType == 2 /*APTTYPE_NA*/) || (aptType == 3 /*APTTYPE_MAINSTA*/));
auto type = get_apartment_type();
switch (type.first)
{
case 0: /* APTTYPE_STA */
case 3: /* APTTYPE_MAINSTA */
return true;
case 2: /* APTTYPE_NA */
return type.second == 3 /* APTTYPEQUALIFIER_NA_ON_STA */ ||
type.second == 5 /* APTTYPEQUALIFIER_NA_ON_MAINSTA */;
}
return false;
}

inline auto apartment_context()
struct resume_apartment_context
{
return requires_apartment_context() ? capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext) : nullptr;
}
com_ptr<IContextCallback> m_context = try_capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext);
int32_t m_context_type = get_apartment_type().first;
};

inline int32_t __stdcall resume_apartment_callback(com_callback_args* args) noexcept
{
std::experimental::coroutine_handle<>::from_address(args->data)();
return 0;
};

inline auto resume_apartment(com_ptr<IContextCallback> const& context, std::experimental::coroutine_handle<> handle)
inline void resume_apartment_sync(com_ptr<IContextCallback> const& context, std::experimental::coroutine_handle<> handle)
{
com_callback_args args{};
args.data = handle.address();

check_hresult(context->ContextCallback(resume_apartment_callback, &args, guid_of<ICallbackWithNoReentrancyToApplicationSTA>(), 5, nullptr));
}

inline void resume_apartment_on_threadpool(com_ptr<IContextCallback> const& context, std::experimental::coroutine_handle<> handle)
{
if (context)
struct threadpool_resume
{
com_callback_args args{};
args.data = handle.address();
threadpool_resume(com_ptr<IContextCallback> const& context, std::experimental::coroutine_handle<> handle) :
m_context(context), m_handle(handle) { }
com_ptr<IContextCallback> m_context;
std::experimental::coroutine_handle<> m_handle;
};
auto state = std::make_unique<threadpool_resume>(context, handle);
submit_threadpool_callback([](void*, void* p)
{
std::unique_ptr<threadpool_resume> state{ static_cast<threadpool_resume*>(p) };
resume_apartment_sync(state->m_context, state->m_handle);
}, state.get());
state.release();
}

check_hresult(context->ContextCallback(resume_apartment_callback, &args, guid_of<ICallbackWithNoReentrancyToApplicationSTA>(), 5, nullptr));
inline auto resume_apartment(resume_apartment_context const& context, std::experimental::coroutine_handle<> handle)
{
if ((context.m_context == nullptr) || (context.m_context == try_capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext)))
{
handle();
}
else if (context.m_context_type == 1 /* APTTYPE_MTA */)
{
resume_background(handle);
}
else if ((context.m_context_type == 2 /* APTTYPE_NTA */) && is_sta_thread())
{
resume_apartment_on_threadpool(context.m_context, handle);
}
else
{
if (requires_apartment_context())
{
resume_background(handle);
}
else
{
handle();
}
resume_apartment_sync(context.m_context, handle);
}
}

Expand Down Expand Up @@ -294,7 +339,7 @@ WINRT_EXPORT namespace winrt
impl::resume_apartment(context, handle);
}

com_ptr<impl::IContextCallback> context = impl::apartment_context();
impl::resume_apartment_context context;
};

[[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept
Expand Down
49 changes: 49 additions & 0 deletions test/old_tests/UnitTests/apartment_context.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "pch.h"
#include "catch.hpp"
#include <ctxtcall.h>

using namespace winrt;
using namespace Windows::Foundation;
using namespace Windows::System;

namespace
{
Expand All @@ -12,9 +14,56 @@ namespace

co_await context;
}

template<typename TLambda>
void InvokeInContext(IContextCallback* context, TLambda&& lambda)
{
ComCallData data;
data.pUserDefined = &lambda;
check_hresult(context->ContextCallback([](ComCallData* data) -> HRESULT
{
auto& lambda = *reinterpret_cast<TLambda*>(data->pUserDefined);
lambda();
return S_OK;
}, &data, IID_ICallbackWithNoReentrancyToApplicationSTA, 5, nullptr));
}

auto get_winrt_apartment_context_for_com_context(com_ptr<::IContextCallback> const& com_context)
{
std::optional<decltype(apartment_context())> context;
InvokeInContext(com_context.get(), [&] {
context = apartment_context();
});
return context.value();
}

bool is_nta_on_mta()
{
APTTYPE type;
APTTYPEQUALIFIER qualifier;
check_hresult(CoGetApartmentType(&type, &qualifier));
return (type == APTTYPE_NA) && (qualifier == APTTYPEQUALIFIER_NA_ON_MTA || qualifier == APTTYPEQUALIFIER_NA_ON_IMPLICIT_MTA);
}

IAsyncAction TestNeutralApartmentContext()
{
auto controller = DispatcherQueueController::CreateOnDedicatedThread();
co_await resume_foreground(controller.DispatcherQueue());

// Entering neutral apartment from STA should resume on explicit background thread.
auto nta = get_winrt_apartment_context_for_com_context(capture<::IContextCallback>(CoGetDefaultContext, APTTYPE_NA));
co_await nta;

REQUIRE(is_nta_on_mta());
}
}

TEST_CASE("apartment_context coverage")
{
Async().get();
}

TEST_CASE("apartment_context nta")
{
TestNeutralApartmentContext().get();
}
24 changes: 24 additions & 0 deletions test/old_tests/UnitTests/capture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,27 @@ TEST_CASE("capture")
REQUIRE_THROWS_AS(capture<IDispatch>(a, &ICapture::CreateMemberCapture, 0), hresult_no_interface);
REQUIRE_THROWS_AS(d.capture(a, &ICapture::CreateMemberCapture, 0), hresult_no_interface);
}

TEST_CASE("try_capture")
{
// Identical to the "capture" test above, just with different
// error handling.
com_ptr<ICapture> a = try_capture<ICapture>(CreateCapture, 10);
REQUIRE(a->GetValue() == 10);
a = nullptr;
REQUIRE(a.try_capture(CreateCapture, 20));
REQUIRE(a->GetValue() == 20);

auto b = try_capture<ICapture>(a, &ICapture::CreateMemberCapture, 30);
REQUIRE(b->GetValue() == 30);
b = nullptr;
REQUIRE(b.try_capture(a, &ICapture::CreateMemberCapture, 40));
REQUIRE(b->GetValue() == 40);

com_ptr<IDispatch> d;

REQUIRE(!try_capture<IDispatch>(CreateCapture, 0));
REQUIRE(!d.try_capture(CreateCapture, 0));
REQUIRE(!try_capture<IDispatch>(a, &ICapture::CreateMemberCapture, 0));
REQUIRE(!d.try_capture(a, &ICapture::CreateMemberCapture, 0));
}
9 changes: 9 additions & 0 deletions test/old_tests/UnitTests/create_instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,12 @@ TEST_CASE("create_instance")
com_ptr<IFileOpenDialog> dialog = create_instance<IFileOpenDialog>(guid_of<FileOpenDialog>());
REQUIRE(dialog);
}

TEST_CASE("try_create_instance")
{
com_ptr<IFileOpenDialog> dialog = try_create_instance<IFileOpenDialog>(guid_of<FileOpenDialog>());
REQUIRE(dialog);

dialog = try_create_instance<IFileOpenDialog>(CLSID_NULL);
REQUIRE(!dialog);
}
22 changes: 15 additions & 7 deletions test/test/await_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ using namespace Windows::System;

namespace
{
bool is_sta()
{
APTTYPE type;
APTTYPEQUALIFIER qualifier;
check_hresult(CoGetApartmentType(&type, &qualifier));
return (type == APTTYPE_STA) || (type == APTTYPE_MAINSTA);
}

static handle signal{ CreateEventW(nullptr, false, false, nullptr) };

IAsyncAction OtherForegroundAsync()
Expand All @@ -29,9 +37,9 @@ namespace

IAsyncAction ForegroundAsync(DispatcherQueue dispatcher)
{
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());
co_await resume_foreground(dispatcher);
REQUIRE(impl::is_sta());
REQUIRE(is_sta());

// This exercises one STA thread waiting on another thus one context callback
// completing on another.
Expand All @@ -48,9 +56,9 @@ namespace

fire_and_forget SignalFromForeground(DispatcherQueue dispatcher)
{
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());
co_await resume_foreground(dispatcher);
REQUIRE(impl::is_sta());
REQUIRE(is_sta());

// Previously, this signal was never raised because the foreground thread
// was always blocked waiting for ContextCallback to return.
Expand All @@ -61,19 +69,19 @@ namespace
{
// Switch to a background (MTA) thread.
co_await resume_background();
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());

// This exercises one MTA thread waiting on another and just completing
// directly without the overhead of a context switch.
co_await OtherBackgroundAsync();
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());

// Wait for a coroutine that completes on a foreground (STA) thread.
co_await ForegroundAsync(dispatcher);

// Resumption should automatically switch to a background (MTA) thread
// without blocking the Completed handler (which would in turn block the foreground thread).
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());

// Attempt to signal from the foreground thread under the assumption
// that the foreground thread is not blocked.
Expand Down