From 19dd55a3692765ed36579d924a98cd62019880e2 Mon Sep 17 00:00:00 2001 From: Raymond Chen Date: Thu, 4 Mar 2021 22:21:54 -0800 Subject: [PATCH 1/2] Allow IAsync...WithProgress to report intermediate results Clients of IAsync...WithProgress are permitted to call GetResults() before the operation has completed, allowing them to observe partial results. This behavior is necessary because the Progress type must be a value type, but the operation may want to report reference types in their progress reports. The way to do that is to report the reference type as the GetResults(). The coroutine itself can use the `set_result()` method on the progress token to report a result prior to completion, and can update that result as often as desired prior to the completion of the coroutine. The value passed to `co_return` acts as the final result of the coroutine. If `set_result()` is never called, then the intermediate result is an empty result (null for reference types, default-initialized for value types). In practice, the result of the coroutine is usually a reference type, and the common pattern is to create the result object, pass it to `set_result()` at the start of the coroutine, update the result object as the coroutine progresses, and (redundantly) pass the same result object to `co_return` when the coroutine completes. ```cpp IAsyncOperationWithProgress BakeMuffinAsync() { auto progress = co_await get_progress_token(); BakingResult result; result.Status(BakingStatus::WarmingUp); progress.set_result(result); progress(0.0); result.Status(BakingStatus::Baking); progress(0.2); result.Muffin(Muffin()); result.Status(BakingStatus::Success); progress(1.0); co_return result; } ``` Since clients of ...WithProgress are permitted to call GetResults() prior to completion, this introduces a few new wrinkles. 1. We need to use the lock to ensure a client doesn't try to read a partial result at the same time we are writing a new one. 2. The GetResults() method must return a copy of the result. This is true even if the coroutine has completed, because a progress handler might go off and do some asynchronous work, and then come back and ask for the intermediate result. If the GetResults() method moves the results of a completed coroutine, then the async progress handler and the co_await will compete to retrieve the completed results, and somebody will lose. As a result, ...WithProgress coroutines are slightly more expensive than non-progress coroutines due to the extra locking and copying. Fortunately, it's mostly pay-for-play. If a coroutine never calls `set_result()`, and the client never calls `GetResults()` from its progress handler, then there is no new lock contention. The final copy is unavoidable, but fortunately, the result is usually a reference type, so you just pay an extra AddRef/Release pair. --- strings/base_coroutine_foundation.h | 38 +++++++++++++++++++++++++++-- test/old_tests/UnitTests/async.cpp | 12 ++++++--- test/test/async_progress.cpp | 27 ++++++++++++++++++-- test/test/async_result.cpp | 10 +++++++- test/test/async_suspend.cpp | 10 +++++++- test/test/pch.h | 36 +++++++++++++++++++++++++++ test/test_win7/async_result.cpp | 10 +++++++- test/test_win7/async_suspend.cpp | 10 +++++++- test/test_win7/pch.h | 36 +++++++++++++++++++++++++++ 9 files changed, 177 insertions(+), 12 deletions(-) diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index a3cf5b780..59f5c162e 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -338,6 +338,13 @@ namespace winrt::impl m_promise->set_progress(result); } + template + void set_result(T&& value) const + { + static_assert(!std::is_same_v, "Setting preliminary results requires IAsync...WithProgress"); + m_promise->return_value(std::forward(value)); + } + private: Promise* m_promise; @@ -454,9 +461,20 @@ namespace winrt::impl { slim_lock_guard const guard(m_lock); - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Completed) + if constexpr (std::is_same_v) + { + if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Completed) + { + return static_cast(this)->get_return_value(); + } + } + else { - return static_cast(this)->get_return_value(); + if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Completed || + m_status.load(std::memory_order_relaxed) == AsyncStatus::Started) + { + return static_cast(this)->copy_return_value(); + } } rethrow_if_failed(); @@ -473,6 +491,10 @@ namespace winrt::impl { } + void copy_return_value() const noexcept + { + } + void set_completed() noexcept { async_completed_handler_t handler; @@ -691,6 +713,11 @@ namespace std::experimental return std::move(m_result); } + TResult copy_return_value() noexcept + { + return m_result; + } + void return_value(TResult&& value) noexcept { m_result = std::move(value); @@ -730,13 +757,20 @@ namespace std::experimental return std::move(m_result); } + TResult copy_return_value() noexcept + { + return m_result; + } + void return_value(TResult&& value) noexcept { + winrt::slim_lock_guard const guard(this->m_lock); m_result = std::move(value); } void return_value(TResult const& value) noexcept { + winrt::slim_lock_guard const guard(this->m_lock); m_result = value; } diff --git a/test/old_tests/UnitTests/async.cpp b/test/old_tests/UnitTests/async.cpp index 385f240ea..4452e6db4 100644 --- a/test/old_tests/UnitTests/async.cpp +++ b/test/old_tests/UnitTests/async.cpp @@ -832,7 +832,8 @@ TEST_CASE("async, Cancel_IAsyncActionWithProgress") handle event { CreateEvent(nullptr, false, false, nullptr)}; IAsyncActionWithProgress async = Cancel_IAsyncActionWithProgress(event.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + // It is legal to read results of an incomplete WithProgress. + REQUIRE_NOTHROW(async.GetResults()); async.Cancel(); SetEvent(event.get()); // signal async to run @@ -862,7 +863,8 @@ TEST_CASE("async, Cancel_IAsyncActionWithProgress, 2") handle event { CreateEvent(nullptr, false, false, nullptr)}; IAsyncActionWithProgress async = Cancel_IAsyncActionWithProgress(event.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + // It is legal to read results of an incomplete WithProgress. + REQUIRE_NOTHROW(async.GetResults()); bool completed = false; bool objectMatches = false; @@ -952,7 +954,8 @@ TEST_CASE("async, Cancel_IAsyncOperationWithProgress") handle event { CreateEvent(nullptr, false, false, nullptr)}; IAsyncOperationWithProgress async = Cancel_IAsyncOperationWithProgress(event.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + // It is legal to read results of an incomplete WithProgress. + REQUIRE_NOTHROW(async.GetResults()); async.Cancel(); SetEvent(event.get()); // signal async to run @@ -982,7 +985,8 @@ TEST_CASE("async, Cancel_IAsyncOperationWithProgress, 2") handle event { CreateEvent(nullptr, false, false, nullptr)}; IAsyncOperationWithProgress async = Cancel_IAsyncOperationWithProgress(event.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + // It is legal to read results of an incomplete WithProgress. + REQUIRE_NOTHROW(async.GetResults()); bool completed = false; bool objectMatches = false; diff --git a/test/test/async_progress.cpp b/test/test/async_progress.cpp index a1ef121c8..1a7dc04fc 100644 --- a/test/test/async_progress.cpp +++ b/test/test/async_progress.cpp @@ -16,16 +16,17 @@ namespace progress(123); } - IAsyncOperationWithProgress Operation(HANDLE event) + IAsyncOperationWithProgress Operation(HANDLE event) { co_await resume_on_signal(event); // Invoke from a lambda to ensure that operator() is const. [progress = co_await get_progress_token()]() { + progress.set_result(L"working"); progress(123); }(); - co_return 1; + co_return L"done"; } template @@ -35,12 +36,21 @@ namespace handle start{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); + using TResult = decltype(async.GetResults()); + bool progress = false; async.Progress([&](auto&& sender, int value) { progress = true; REQUIRE(async == sender); + REQUIRE(async.Status() == AsyncStatus::Started); + if constexpr (std::is_same_v) + { + REQUIRE(async.GetResults() == L"working"); + // Confirm that reading does not destroy partial results. + REQUIRE(async.GetResults() == L"working"); + } REQUIRE(value == 123); }); @@ -50,6 +60,19 @@ namespace REQUIRE(progress); REQUIRE(async.Status() == AsyncStatus::Completed); REQUIRE(async.ErrorCode() == S_OK); + + // Confirm that you can read results from a completed WithProgress multiple times. + // (We must allow this to avoid race conditions where a progress callback + // does async work and then tries to read an intermediate result, only to + // discover that the operation has already completed.) + if constexpr (std::is_same_v) + { + REQUIRE(async.GetResults() == L"done"); + } + else + { + REQUIRE_NOTHROW(async.GetResults()); + } } template diff --git a/test/test/async_result.cpp b/test/test/async_result.cpp index f02c10ce4..9c1accccb 100644 --- a/test/test/async_result.cpp +++ b/test/test/async_result.cpp @@ -40,7 +40,15 @@ namespace handle completed{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + if constexpr (has_async_progress) + { + // You're allowed to peek at partial results of IAsyncXxxWithProgress. + REQUIRE_NOTHROW(async.GetResults()); + } + else + { + REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + } async.Completed([&](auto&& sender, AsyncStatus status) { diff --git a/test/test/async_suspend.cpp b/test/test/async_suspend.cpp index 3c0fd791e..f7beee21d 100644 --- a/test/test/async_suspend.cpp +++ b/test/test/async_suspend.cpp @@ -49,7 +49,15 @@ namespace handle completed{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + if constexpr (has_async_progress) + { + // You're allowed to peek at partial results of IAsyncXxxWithProgress. + REQUIRE_NOTHROW(async.GetResults()); + } + else + { + REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + } async.Completed([&](auto&& sender, AsyncStatus status) { diff --git a/test/test/pch.h b/test/test/pch.h index ea74230cc..916e8fe8f 100644 --- a/test/test/pch.h +++ b/test/test/pch.h @@ -10,3 +10,39 @@ #include "catch.hpp" using namespace std::literals; + +// Extracts return and progress types from IAsyncXxx. + +template +struct async_traits; + +template<> +struct async_traits +{ + using progress_type = void; +}; + +template +struct async_traits> +{ + using progress_type = P; +}; + +template +struct async_traits> +{ + using progress_type = void; +}; + +template +struct async_traits> +{ + using progress_type = P; +}; + +template +using async_return_type = decltype(std::declval().GetResults()); +template +using async_progress_type = typename async_traits>::progress_type; +template +inline constexpr bool has_async_progress = !std::is_same_v>::progress_type>; diff --git a/test/test_win7/async_result.cpp b/test/test_win7/async_result.cpp index f02c10ce4..9c1accccb 100644 --- a/test/test_win7/async_result.cpp +++ b/test/test_win7/async_result.cpp @@ -40,7 +40,15 @@ namespace handle completed{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + if constexpr (has_async_progress) + { + // You're allowed to peek at partial results of IAsyncXxxWithProgress. + REQUIRE_NOTHROW(async.GetResults()); + } + else + { + REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + } async.Completed([&](auto&& sender, AsyncStatus status) { diff --git a/test/test_win7/async_suspend.cpp b/test/test_win7/async_suspend.cpp index 3c0fd791e..f7beee21d 100644 --- a/test/test_win7/async_suspend.cpp +++ b/test/test_win7/async_suspend.cpp @@ -49,7 +49,15 @@ namespace handle completed{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + if constexpr (has_async_progress) + { + // You're allowed to peek at partial results of IAsyncXxxWithProgress. + REQUIRE_NOTHROW(async.GetResults()); + } + else + { + REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + } async.Completed([&](auto&& sender, AsyncStatus status) { diff --git a/test/test_win7/pch.h b/test/test_win7/pch.h index 2f7ee3989..84bb40137 100644 --- a/test/test_win7/pch.h +++ b/test/test_win7/pch.h @@ -7,3 +7,39 @@ #include "catch.hpp" using namespace std::literals; + +// Extracts return and progress types from IAsyncXxx. + +template +struct async_traits; + +template<> +struct async_traits +{ + using progress_type = void; +}; + +template +struct async_traits> +{ + using progress_type = P; +}; + +template +struct async_traits> +{ + using progress_type = void; +}; + +template +struct async_traits> +{ + using progress_type = P; +}; + +template +using async_return_type = decltype(std::declval().GetResults()); +template +using async_progress_type = typename async_traits>::progress_type; +template +inline constexpr bool has_async_progress = !std::is_same_v>::progress_type>; From 4276fafcabdd67efb78c2c46dd4a33ed267d9405 Mon Sep 17 00:00:00 2001 From: Raymond Chen Date: Fri, 5 Mar 2021 12:26:43 -0800 Subject: [PATCH 2/2] Don't call m_status.load() so often Every call to `atomic.load()` goes to memory. Cache the value to remove extra memory accesses. Split more of the implementation of `GetResult()` since there are optimizations available to the ...WithProgress cases, since the `illegal_method_call` case can never happen. --- strings/base_coroutine_foundation.h | 33 ++++++++++++++++------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index 59f5c162e..943e640ed 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -382,13 +382,12 @@ namespace winrt::impl m_completed_assigned = true; - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Started) + status = m_status.load(std::memory_order_relaxed); + if (status == AsyncStatus::Started) { m_completed = make_agile_delegate(handler); return; } - - status = m_status.load(std::memory_order_relaxed); } if (handler) @@ -421,7 +420,7 @@ namespace winrt::impl try { slim_lock_guard const guard(m_lock); - rethrow_if_failed(); + rethrow_if_failed(m_status.load(std::memory_order_relaxed)); return 0; } catch (...) @@ -461,25 +460,28 @@ namespace winrt::impl { slim_lock_guard const guard(m_lock); + auto status = m_status.load(std::memory_order_relaxed); + if constexpr (std::is_same_v) { - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Completed) + if (status == AsyncStatus::Completed) { return static_cast(this)->get_return_value(); } + rethrow_if_failed(status); + WINRT_ASSERT(status == AsyncStatus::Started); + throw hresult_illegal_method_call(); } else { - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Completed || - m_status.load(std::memory_order_relaxed) == AsyncStatus::Started) + if (status == AsyncStatus::Completed || status == AsyncStatus::Started) { return static_cast(this)->copy_return_value(); } + WINRT_ASSERT(status == AsyncStatus::Error || status == AsyncStatus::Canceled); + std::rethrow_exception(m_exception); } - rethrow_if_failed(); - WINRT_ASSERT(m_status.load(std::memory_order_relaxed) == AsyncStatus::Started); - throw hresult_illegal_method_call(); } AsyncInterface get_return_object() const noexcept @@ -503,13 +505,14 @@ namespace winrt::impl { slim_lock_guard const guard(m_lock); - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Started) + status = m_status.load(std::memory_order_relaxed); + if (status == AsyncStatus::Started) { - m_status.store(AsyncStatus::Completed, std::memory_order_relaxed); + status = AsyncStatus::Completed; + m_status.store(status, std::memory_order_relaxed); } handler = std::move(this->m_completed); - status = this->m_status.load(std::memory_order_relaxed); } if (handler) @@ -632,9 +635,9 @@ namespace winrt::impl protected: - void rethrow_if_failed() const + void rethrow_if_failed(AsyncStatus status) const { - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Error || m_status.load(std::memory_order_relaxed) == AsyncStatus::Canceled) + if (status == AsyncStatus::Error || status == AsyncStatus::Canceled) { std::rethrow_exception(m_exception); }