Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions strings/base_coroutine_foundation.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,13 @@ namespace winrt::impl
m_promise->set_progress(result);
}

template<typename T>
void set_result(T&& value) const
{
static_assert(!std::is_same_v<Progress, void>, "Setting preliminary results requires IAsync...WithProgress");
m_promise->return_value(std::forward<T>(value));
}

private:

Promise* m_promise;
Expand Down Expand Up @@ -375,13 +382,12 @@ namespace winrt::impl

m_completed_assigned = true;

if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Started)
status = m_status.load(std::memory_order_relaxed);
if (status == AsyncStatus::Started)
{
m_completed = make_agile_delegate(handler);
return;
}

status = m_status.load(std::memory_order_relaxed);
}

if (handler)
Expand Down Expand Up @@ -414,7 +420,7 @@ namespace winrt::impl
try
{
slim_lock_guard const guard(m_lock);
rethrow_if_failed();
rethrow_if_failed(m_status.load(std::memory_order_relaxed));
return 0;
}
catch (...)
Expand Down Expand Up @@ -454,14 +460,28 @@ namespace winrt::impl
{
slim_lock_guard const guard(m_lock);

if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Completed)
auto status = m_status.load(std::memory_order_relaxed);

if constexpr (std::is_same_v<TProgress, void>)
{
return static_cast<Derived*>(this)->get_return_value();
if (status == AsyncStatus::Completed)
{
return static_cast<Derived*>(this)->get_return_value();
}
rethrow_if_failed(status);
WINRT_ASSERT(status == AsyncStatus::Started);
throw hresult_illegal_method_call();
}
else
{
if (status == AsyncStatus::Completed || status == AsyncStatus::Started)
{
return static_cast<Derived*>(this)->copy_return_value();
}
WINRT_ASSERT(status == AsyncStatus::Error || status == AsyncStatus::Canceled);
std::rethrow_exception(m_exception);
}

rethrow_if_failed();
WINRT_ASSERT(m_status.load(std::memory_order_relaxed) == AsyncStatus::Started);
throw hresult_illegal_method_call();
}

AsyncInterface get_return_object() const noexcept
Expand All @@ -473,6 +493,10 @@ namespace winrt::impl
{
}

void copy_return_value() const noexcept
{
}

void set_completed() noexcept
{
async_completed_handler_t<AsyncInterface> handler;
Expand All @@ -481,13 +505,14 @@ namespace winrt::impl
{
slim_lock_guard const guard(m_lock);

if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Started)
status = m_status.load(std::memory_order_relaxed);
if (status == AsyncStatus::Started)
{
m_status.store(AsyncStatus::Completed, std::memory_order_relaxed);
status = AsyncStatus::Completed;
m_status.store(status, std::memory_order_relaxed);
}

handler = std::move(this->m_completed);
status = this->m_status.load(std::memory_order_relaxed);
}

if (handler)
Expand Down Expand Up @@ -610,9 +635,9 @@ namespace winrt::impl

protected:

void rethrow_if_failed() const
void rethrow_if_failed(AsyncStatus status) const
{
if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Error || m_status.load(std::memory_order_relaxed) == AsyncStatus::Canceled)
if (status == AsyncStatus::Error || status == AsyncStatus::Canceled)
{
std::rethrow_exception(m_exception);
}
Expand Down Expand Up @@ -691,6 +716,11 @@ namespace std::experimental
return std::move(m_result);
}

TResult copy_return_value() noexcept
{
return m_result;
}

void return_value(TResult&& value) noexcept
{
m_result = std::move(value);
Expand Down Expand Up @@ -730,13 +760,20 @@ namespace std::experimental
return std::move(m_result);
}

TResult copy_return_value() noexcept
{
return m_result;
}

void return_value(TResult&& value) noexcept
{
winrt::slim_lock_guard const guard(this->m_lock);
m_result = std::move(value);
}

void return_value(TResult const& value) noexcept
{
winrt::slim_lock_guard const guard(this->m_lock);
m_result = value;
}

Expand Down
12 changes: 8 additions & 4 deletions test/old_tests/UnitTests/async.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,8 @@ TEST_CASE("async, Cancel_IAsyncActionWithProgress")
handle event { CreateEvent(nullptr, false, false, nullptr)};
IAsyncActionWithProgress<double> async = Cancel_IAsyncActionWithProgress(event.get());
REQUIRE(async.Status() == AsyncStatus::Started);
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
// It is legal to read results of an incomplete WithProgress.
REQUIRE_NOTHROW(async.GetResults());

async.Cancel();
SetEvent(event.get()); // signal async to run
Expand Down Expand Up @@ -862,7 +863,8 @@ TEST_CASE("async, Cancel_IAsyncActionWithProgress, 2")
handle event { CreateEvent(nullptr, false, false, nullptr)};
IAsyncActionWithProgress<double> async = Cancel_IAsyncActionWithProgress(event.get());
REQUIRE(async.Status() == AsyncStatus::Started);
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
// It is legal to read results of an incomplete WithProgress.
REQUIRE_NOTHROW(async.GetResults());

bool completed = false;
bool objectMatches = false;
Expand Down Expand Up @@ -952,7 +954,8 @@ TEST_CASE("async, Cancel_IAsyncOperationWithProgress")
handle event { CreateEvent(nullptr, false, false, nullptr)};
IAsyncOperationWithProgress<uint64_t, uint64_t> async = Cancel_IAsyncOperationWithProgress(event.get());
REQUIRE(async.Status() == AsyncStatus::Started);
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
// It is legal to read results of an incomplete WithProgress.
REQUIRE_NOTHROW(async.GetResults());

async.Cancel();
SetEvent(event.get()); // signal async to run
Expand Down Expand Up @@ -982,7 +985,8 @@ TEST_CASE("async, Cancel_IAsyncOperationWithProgress, 2")
handle event { CreateEvent(nullptr, false, false, nullptr)};
IAsyncOperationWithProgress<uint64_t, uint64_t> async = Cancel_IAsyncOperationWithProgress(event.get());
REQUIRE(async.Status() == AsyncStatus::Started);
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
// It is legal to read results of an incomplete WithProgress.
REQUIRE_NOTHROW(async.GetResults());

bool completed = false;
bool objectMatches = false;
Expand Down
27 changes: 25 additions & 2 deletions test/test/async_progress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@ namespace
progress(123);
}

IAsyncOperationWithProgress<int, int> Operation(HANDLE event)
IAsyncOperationWithProgress<hstring, int> Operation(HANDLE event)
{
co_await resume_on_signal(event);

// Invoke from a lambda to ensure that operator() is const.
[progress = co_await get_progress_token()]()
{
progress.set_result(L"working");
progress(123);
}();
co_return 1;
co_return L"done";
}

template <typename F>
Expand All @@ -35,12 +36,21 @@ namespace
handle start{ CreateEvent(nullptr, true, false, nullptr) };

auto async = make(start.get());
using TResult = decltype(async.GetResults());

bool progress = false;

async.Progress([&](auto&& sender, int value)
{
progress = true;
REQUIRE(async == sender);
REQUIRE(async.Status() == AsyncStatus::Started);
if constexpr (std::is_same_v<TResult, hstring>)
{
REQUIRE(async.GetResults() == L"working");
// Confirm that reading does not destroy partial results.
REQUIRE(async.GetResults() == L"working");
}
REQUIRE(value == 123);
});

Expand All @@ -50,6 +60,19 @@ namespace
REQUIRE(progress);
REQUIRE(async.Status() == AsyncStatus::Completed);
REQUIRE(async.ErrorCode() == S_OK);

// Confirm that you can read results from a completed WithProgress multiple times.
// (We must allow this to avoid race conditions where a progress callback
// does async work and then tries to read an intermediate result, only to
// discover that the operation has already completed.)
if constexpr (std::is_same_v<TResult, hstring>)
{
REQUIRE(async.GetResults() == L"done");
}
else
{
REQUIRE_NOTHROW(async.GetResults());
}
}

template <typename F>
Expand Down
10 changes: 9 additions & 1 deletion test/test/async_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,15 @@ namespace
handle completed{ CreateEvent(nullptr, true, false, nullptr) };
auto async = make(start.get());
REQUIRE(async.Status() == AsyncStatus::Started);
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
if constexpr (has_async_progress<decltype(async)>)
{
// You're allowed to peek at partial results of IAsyncXxxWithProgress.
REQUIRE_NOTHROW(async.GetResults());
}
else
{
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
}

async.Completed([&](auto&& sender, AsyncStatus status)
{
Expand Down
10 changes: 9 additions & 1 deletion test/test/async_suspend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,15 @@ namespace
handle completed{ CreateEvent(nullptr, true, false, nullptr) };
auto async = make(start.get());
REQUIRE(async.Status() == AsyncStatus::Started);
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
if constexpr (has_async_progress<decltype(async)>)
{
// You're allowed to peek at partial results of IAsyncXxxWithProgress.
REQUIRE_NOTHROW(async.GetResults());
}
else
{
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
}

async.Completed([&](auto&& sender, AsyncStatus status)
{
Expand Down
36 changes: 36 additions & 0 deletions test/test/pch.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,39 @@
#include "catch.hpp"

using namespace std::literals;

// Extracts return and progress types from IAsyncXxx.

template<typename T>
struct async_traits;

template<>
struct async_traits<winrt::Windows::Foundation::IAsyncAction>
{
using progress_type = void;
};

template<typename P>
struct async_traits<winrt::Windows::Foundation::IAsyncActionWithProgress<P>>
{
using progress_type = P;
};

template<typename R>
struct async_traits<winrt::Windows::Foundation::IAsyncOperation<R>>
{
using progress_type = void;
};

template<typename R, typename P>
struct async_traits<winrt::Windows::Foundation::IAsyncOperationWithProgress<R, P>>
{
using progress_type = P;
};

template<typename T>
using async_return_type = decltype(std::declval<T>().GetResults());
template<typename T>
using async_progress_type = typename async_traits<std::decay_t<T>>::progress_type;
template<typename T>
inline constexpr bool has_async_progress = !std::is_same_v<void, async_traits<std::decay_t<T>>::progress_type>;
10 changes: 9 additions & 1 deletion test/test_win7/async_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,15 @@ namespace
handle completed{ CreateEvent(nullptr, true, false, nullptr) };
auto async = make(start.get());
REQUIRE(async.Status() == AsyncStatus::Started);
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
if constexpr (has_async_progress<decltype(async)>)
{
// You're allowed to peek at partial results of IAsyncXxxWithProgress.
REQUIRE_NOTHROW(async.GetResults());
}
else
{
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
}

async.Completed([&](auto&& sender, AsyncStatus status)
{
Expand Down
10 changes: 9 additions & 1 deletion test/test_win7/async_suspend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,15 @@ namespace
handle completed{ CreateEvent(nullptr, true, false, nullptr) };
auto async = make(start.get());
REQUIRE(async.Status() == AsyncStatus::Started);
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
if constexpr (has_async_progress<decltype(async)>)
{
// You're allowed to peek at partial results of IAsyncXxxWithProgress.
REQUIRE_NOTHROW(async.GetResults());
}
else
{
REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call);
}

async.Completed([&](auto&& sender, AsyncStatus status)
{
Expand Down
36 changes: 36 additions & 0 deletions test/test_win7/pch.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,39 @@
#include "catch.hpp"

using namespace std::literals;

// Extracts return and progress types from IAsyncXxx.

template<typename T>
struct async_traits;

template<>
struct async_traits<winrt::Windows::Foundation::IAsyncAction>
{
using progress_type = void;
};

template<typename P>
struct async_traits<winrt::Windows::Foundation::IAsyncActionWithProgress<P>>
{
using progress_type = P;
};

template<typename R>
struct async_traits<winrt::Windows::Foundation::IAsyncOperation<R>>
{
using progress_type = void;
};

template<typename R, typename P>
struct async_traits<winrt::Windows::Foundation::IAsyncOperationWithProgress<R, P>>
{
using progress_type = P;
};

template<typename T>
using async_return_type = decltype(std::declval<T>().GetResults());
template<typename T>
using async_progress_type = typename async_traits<std::decay_t<T>>::progress_type;
template<typename T>
inline constexpr bool has_async_progress = !std::is_same_v<void, async_traits<std::decay_t<T>>::progress_type>;