Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions strings/base_coroutine_foundation.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ namespace winrt::impl

struct disconnect_aware_handler
{
disconnect_aware_handler(coroutine_handle<> handle) noexcept
: m_handle(handle) { }
disconnect_aware_handler(coroutine_handle<> handle, int32_t* failure) noexcept
: m_handle(handle), m_failure(failure) { }

disconnect_aware_handler(disconnect_aware_handler&& other) noexcept
: m_context(std::move(other.m_context))
, m_handle(std::exchange(other.m_handle, {})) { }
, m_handle(std::exchange(other.m_handle, {}))
, m_failure(other.m_failure) { }

~disconnect_aware_handler()
{
Expand All @@ -120,10 +121,11 @@ namespace winrt::impl
private:
resume_apartment_context m_context;
coroutine_handle<> m_handle;
int32_t* m_failure;

void Complete()
{
resume_apartment(m_context, std::exchange(m_handle, {}));
resume_apartment(m_context, std::exchange(m_handle, {}), m_failure);
}
};

Expand All @@ -135,6 +137,7 @@ namespace winrt::impl

Async const& async;
Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started;
int32_t failure = 0;

void enable_cancellation(cancellable_promise* promise)
{
Expand All @@ -152,7 +155,7 @@ namespace winrt::impl
void await_suspend(coroutine_handle<> handle)
{
auto extend_lifetime = async;
async.Completed([this, handler = disconnect_aware_handler{ handle }](auto&&, auto operation_status) mutable
async.Completed([this, handler = disconnect_aware_handler{ handle, &failure }](auto&&, auto operation_status) mutable
{
status = operation_status;
handler();
Expand All @@ -161,6 +164,7 @@ namespace winrt::impl

auto await_resume() const
{
check_hresult(failure);
check_status_canceled(status);
return async.GetResults();
}
Expand Down
58 changes: 42 additions & 16 deletions strings/base_coroutine_threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,36 +77,43 @@ namespace winrt::impl
return 0;
};

inline void resume_apartment_sync(com_ptr<IContextCallback> const& context, coroutine_handle<> handle)
inline void resume_apartment_sync(com_ptr<IContextCallback> const& context, coroutine_handle<> handle, int32_t* failure)
{
com_callback_args args{};
args.data = handle.address();

check_hresult(context->ContextCallback(resume_apartment_callback, &args, guid_of<ICallbackWithNoReentrancyToApplicationSTA>(), 5, nullptr));
auto result = context->ContextCallback(resume_apartment_callback, &args, guid_of<ICallbackWithNoReentrancyToApplicationSTA>(), 5, nullptr);
if (result < 0)
{
// Resume the coroutine on the wrong apartment, but tell it why.
*failure = result;
handle();
}
}

struct threadpool_resume
{
threadpool_resume(com_ptr<IContextCallback> const& context, coroutine_handle<> handle) :
m_context(context), m_handle(handle) { }
threadpool_resume(com_ptr<IContextCallback> const& context, coroutine_handle<> handle, int32_t* failure) :
m_context(context), m_handle(handle), m_failure(failure) { }
com_ptr<IContextCallback> m_context;
coroutine_handle<> m_handle;
int32_t* m_failure;
};

inline void __stdcall fallback_submit_threadpool_callback(void*, void* p) noexcept
{
std::unique_ptr<threadpool_resume> state{ static_cast<threadpool_resume*>(p) };
resume_apartment_sync(state->m_context, state->m_handle);
resume_apartment_sync(state->m_context, state->m_handle, state->m_failure);
}

inline void resume_apartment_on_threadpool(com_ptr<IContextCallback> const& context, coroutine_handle<> handle)
inline void resume_apartment_on_threadpool(com_ptr<IContextCallback> const& context, coroutine_handle<> handle, int32_t* failure)
{
auto state = std::make_unique<threadpool_resume>(context, handle);
auto state = std::make_unique<threadpool_resume>(context, handle, failure);
submit_threadpool_callback(fallback_submit_threadpool_callback, state.get());
state.release();
}

inline auto resume_apartment(resume_apartment_context const& context, coroutine_handle<> handle)
inline auto resume_apartment(resume_apartment_context const& context, coroutine_handle<> handle, int32_t* failure)
{
WINRT_ASSERT(context.valid());
if ((context.m_context == nullptr) || (context.m_context == try_capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext)))
Expand All @@ -119,11 +126,11 @@ namespace winrt::impl
}
else if (is_sta_thread())
{
resume_apartment_on_threadpool(context.m_context, handle);
resume_apartment_on_threadpool(context.m_context, handle, failure);
}
else
{
resume_apartment_sync(context.m_context, handle);
resume_apartment_sync(context.m_context, handle, failure);
}
}

Expand Down Expand Up @@ -365,23 +372,42 @@ WINRT_EXPORT namespace winrt
operator bool() const noexcept { return context.valid(); }
bool operator!() const noexcept { return !context.valid(); }

impl::resume_apartment_context context;
};
}

namespace winrt::impl
{
struct apartment_awaiter
{
apartment_context context; // make a copy because resuming may destruct the original
int32_t failure = 0;

bool await_ready() const noexcept
{
return false;
}

void await_resume() const noexcept
void await_resume() const
{
check_hresult(failure);
}

void await_suspend(impl::coroutine_handle<> handle) const
void await_suspend(impl::coroutine_handle<> handle)
{
auto copy = context; // resuming may destruct *this, so use a copy
impl::resume_apartment(copy, handle);
impl::resume_apartment(context.context, handle, &failure);
}

impl::resume_apartment_context context;
};
}

WINRT_EXPORT namespace winrt
{
#ifdef WINRT_IMPL_COROUTINES
inline impl::apartment_awaiter operator co_await(apartment_context const& context)
{
return{ context };
}
#endif

[[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept
{
Expand Down
99 changes: 98 additions & 1 deletion test/old_tests/UnitTests/apartment_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ namespace
return (type == APTTYPE_NA) && (qualifier == APTTYPEQUALIFIER_NA_ON_MTA || qualifier == APTTYPEQUALIFIER_NA_ON_IMPLICIT_MTA);
}

bool is_mta()
{
APTTYPE type;
APTTYPEQUALIFIER qualifier;
check_hresult(CoGetApartmentType(&type, &qualifier));
return type == APTTYPE_MTA;
}

bool is_invalid_context_error(HRESULT hr)
{
return (hr == RPC_E_SERVER_DIED_DNE) || (hr == RPC_E_DISCONNECTED);
}

IAsyncAction TestNeutralApartmentContext()
{
auto controller = DispatcherQueueController::CreateOnDedicatedThread();
Expand All @@ -93,7 +106,7 @@ namespace
{
bool pass = false;

apartment_context original;
const apartment_context original;

// Create an STA thread and switch to it.
auto controller1 = DispatcherQueueController::CreateOnDedicatedThread();
Expand Down Expand Up @@ -144,10 +157,89 @@ namespace

co_await controller1.ShutdownQueueAsync();
co_await controller2.ShutdownQueueAsync();
}

IAsyncAction TestDisconnectedApartmentContext()
{
// Create an STA thread and switch to it.
auto controller1 = DispatcherQueueController::CreateOnDedicatedThread();
co_await resume_foreground(controller1.DispatcherQueue());

// Create another STA thread.
auto controller2 = DispatcherQueueController::CreateOnDedicatedThread();

// Test returning to a destroyed context after IAsyncXxx completion on STA.
HRESULT error = S_OK;
try
{
co_await[](auto controller1, auto controller2) -> IAsyncAction
{
co_await resume_foreground(controller2.DispatcherQueue());
co_await controller1.ShutdownQueueAsync();
}(controller1, controller2);
}
catch (...)
{
error = winrt::to_hresult();
REQUIRE(is_mta());
}
REQUIRE(is_invalid_context_error(error));

// Re-create the STA thread for the next test.
controller1 = DispatcherQueueController::CreateOnDedicatedThread();
co_await resume_foreground(controller1.DispatcherQueue());

// Save the COM context for the first STA thread.
apartment_context context1;

// Test returning to a destroyed context after IAsyncXxx completion on MTA.
error = S_OK;
try
{
co_await[](auto controller1) -> IAsyncAction
{
co_await resume_background(); // doesn't work
co_await controller1.ShutdownQueueAsync();
}(controller1);
}
catch (...)
{
error = winrt::to_hresult();
REQUIRE(is_mta());
}
REQUIRE(is_invalid_context_error(error));

// Test switching to a destroyed context from an STA.
co_await resume_foreground(controller2.DispatcherQueue());

error = S_OK;
try {
co_await context1;
}
catch (...)
{
error = winrt::to_hresult();
REQUIRE(is_mta());
}
REQUIRE(is_invalid_context_error(error));

// Test switching to a destroyed context from an MTA.
error = S_OK;
try {
co_await context1;
}
catch (...)
{
error = winrt::to_hresult();
REQUIRE(is_mta());
}
REQUIRE(is_invalid_context_error(error));

// Clean up.
co_await controller2.ShutdownQueueAsync();
}

}
TEST_CASE("apartment_context coverage")
{
Async().get();
Expand All @@ -162,3 +254,8 @@ TEST_CASE("apartment_context sta")
{
TestStaToStaApartmentContext().get();
}

TEST_CASE("apartment_context disconnected")
{
TestDisconnectedApartmentContext().get();
}
Loading