diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index a3cf5b780..943e640ed 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -338,6 +338,13 @@ namespace winrt::impl m_promise->set_progress(result); } + template + void set_result(T&& value) const + { + static_assert(!std::is_same_v, "Setting preliminary results requires IAsync...WithProgress"); + m_promise->return_value(std::forward(value)); + } + private: Promise* m_promise; @@ -375,13 +382,12 @@ namespace winrt::impl m_completed_assigned = true; - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Started) + status = m_status.load(std::memory_order_relaxed); + if (status == AsyncStatus::Started) { m_completed = make_agile_delegate(handler); return; } - - status = m_status.load(std::memory_order_relaxed); } if (handler) @@ -414,7 +420,7 @@ namespace winrt::impl try { slim_lock_guard const guard(m_lock); - rethrow_if_failed(); + rethrow_if_failed(m_status.load(std::memory_order_relaxed)); return 0; } catch (...) @@ -454,14 +460,28 @@ namespace winrt::impl { slim_lock_guard const guard(m_lock); - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Completed) + auto status = m_status.load(std::memory_order_relaxed); + + if constexpr (std::is_same_v) { - return static_cast(this)->get_return_value(); + if (status == AsyncStatus::Completed) + { + return static_cast(this)->get_return_value(); + } + rethrow_if_failed(status); + WINRT_ASSERT(status == AsyncStatus::Started); + throw hresult_illegal_method_call(); + } + else + { + if (status == AsyncStatus::Completed || status == AsyncStatus::Started) + { + return static_cast(this)->copy_return_value(); + } + WINRT_ASSERT(status == AsyncStatus::Error || status == AsyncStatus::Canceled); + std::rethrow_exception(m_exception); } - rethrow_if_failed(); - WINRT_ASSERT(m_status.load(std::memory_order_relaxed) == AsyncStatus::Started); - throw hresult_illegal_method_call(); } AsyncInterface get_return_object() const noexcept @@ -473,6 +493,10 @@ namespace winrt::impl { } + void copy_return_value() const noexcept + { + } + void set_completed() noexcept { async_completed_handler_t handler; @@ -481,13 +505,14 @@ namespace winrt::impl { slim_lock_guard const guard(m_lock); - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Started) + status = m_status.load(std::memory_order_relaxed); + if (status == AsyncStatus::Started) { - m_status.store(AsyncStatus::Completed, std::memory_order_relaxed); + status = AsyncStatus::Completed; + m_status.store(status, std::memory_order_relaxed); } handler = std::move(this->m_completed); - status = this->m_status.load(std::memory_order_relaxed); } if (handler) @@ -610,9 +635,9 @@ namespace winrt::impl protected: - void rethrow_if_failed() const + void rethrow_if_failed(AsyncStatus status) const { - if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Error || m_status.load(std::memory_order_relaxed) == AsyncStatus::Canceled) + if (status == AsyncStatus::Error || status == AsyncStatus::Canceled) { std::rethrow_exception(m_exception); } @@ -691,6 +716,11 @@ namespace std::experimental return std::move(m_result); } + TResult copy_return_value() noexcept + { + return m_result; + } + void return_value(TResult&& value) noexcept { m_result = std::move(value); @@ -730,13 +760,20 @@ namespace std::experimental return std::move(m_result); } + TResult copy_return_value() noexcept + { + return m_result; + } + void return_value(TResult&& value) noexcept { + winrt::slim_lock_guard const guard(this->m_lock); m_result = std::move(value); } void return_value(TResult const& value) noexcept { + winrt::slim_lock_guard const guard(this->m_lock); m_result = value; } diff --git a/test/old_tests/UnitTests/async.cpp b/test/old_tests/UnitTests/async.cpp index 385f240ea..4452e6db4 100644 --- a/test/old_tests/UnitTests/async.cpp +++ b/test/old_tests/UnitTests/async.cpp @@ -832,7 +832,8 @@ TEST_CASE("async, Cancel_IAsyncActionWithProgress") handle event { CreateEvent(nullptr, false, false, nullptr)}; IAsyncActionWithProgress async = Cancel_IAsyncActionWithProgress(event.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + // It is legal to read results of an incomplete WithProgress. + REQUIRE_NOTHROW(async.GetResults()); async.Cancel(); SetEvent(event.get()); // signal async to run @@ -862,7 +863,8 @@ TEST_CASE("async, Cancel_IAsyncActionWithProgress, 2") handle event { CreateEvent(nullptr, false, false, nullptr)}; IAsyncActionWithProgress async = Cancel_IAsyncActionWithProgress(event.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + // It is legal to read results of an incomplete WithProgress. + REQUIRE_NOTHROW(async.GetResults()); bool completed = false; bool objectMatches = false; @@ -952,7 +954,8 @@ TEST_CASE("async, Cancel_IAsyncOperationWithProgress") handle event { CreateEvent(nullptr, false, false, nullptr)}; IAsyncOperationWithProgress async = Cancel_IAsyncOperationWithProgress(event.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + // It is legal to read results of an incomplete WithProgress. + REQUIRE_NOTHROW(async.GetResults()); async.Cancel(); SetEvent(event.get()); // signal async to run @@ -982,7 +985,8 @@ TEST_CASE("async, Cancel_IAsyncOperationWithProgress, 2") handle event { CreateEvent(nullptr, false, false, nullptr)}; IAsyncOperationWithProgress async = Cancel_IAsyncOperationWithProgress(event.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + // It is legal to read results of an incomplete WithProgress. + REQUIRE_NOTHROW(async.GetResults()); bool completed = false; bool objectMatches = false; diff --git a/test/test/async_progress.cpp b/test/test/async_progress.cpp index a1ef121c8..1a7dc04fc 100644 --- a/test/test/async_progress.cpp +++ b/test/test/async_progress.cpp @@ -16,16 +16,17 @@ namespace progress(123); } - IAsyncOperationWithProgress Operation(HANDLE event) + IAsyncOperationWithProgress Operation(HANDLE event) { co_await resume_on_signal(event); // Invoke from a lambda to ensure that operator() is const. [progress = co_await get_progress_token()]() { + progress.set_result(L"working"); progress(123); }(); - co_return 1; + co_return L"done"; } template @@ -35,12 +36,21 @@ namespace handle start{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); + using TResult = decltype(async.GetResults()); + bool progress = false; async.Progress([&](auto&& sender, int value) { progress = true; REQUIRE(async == sender); + REQUIRE(async.Status() == AsyncStatus::Started); + if constexpr (std::is_same_v) + { + REQUIRE(async.GetResults() == L"working"); + // Confirm that reading does not destroy partial results. + REQUIRE(async.GetResults() == L"working"); + } REQUIRE(value == 123); }); @@ -50,6 +60,19 @@ namespace REQUIRE(progress); REQUIRE(async.Status() == AsyncStatus::Completed); REQUIRE(async.ErrorCode() == S_OK); + + // Confirm that you can read results from a completed WithProgress multiple times. + // (We must allow this to avoid race conditions where a progress callback + // does async work and then tries to read an intermediate result, only to + // discover that the operation has already completed.) + if constexpr (std::is_same_v) + { + REQUIRE(async.GetResults() == L"done"); + } + else + { + REQUIRE_NOTHROW(async.GetResults()); + } } template diff --git a/test/test/async_result.cpp b/test/test/async_result.cpp index f02c10ce4..9c1accccb 100644 --- a/test/test/async_result.cpp +++ b/test/test/async_result.cpp @@ -40,7 +40,15 @@ namespace handle completed{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + if constexpr (has_async_progress) + { + // You're allowed to peek at partial results of IAsyncXxxWithProgress. + REQUIRE_NOTHROW(async.GetResults()); + } + else + { + REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + } async.Completed([&](auto&& sender, AsyncStatus status) { diff --git a/test/test/async_suspend.cpp b/test/test/async_suspend.cpp index 3c0fd791e..f7beee21d 100644 --- a/test/test/async_suspend.cpp +++ b/test/test/async_suspend.cpp @@ -49,7 +49,15 @@ namespace handle completed{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + if constexpr (has_async_progress) + { + // You're allowed to peek at partial results of IAsyncXxxWithProgress. + REQUIRE_NOTHROW(async.GetResults()); + } + else + { + REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + } async.Completed([&](auto&& sender, AsyncStatus status) { diff --git a/test/test/pch.h b/test/test/pch.h index ea74230cc..916e8fe8f 100644 --- a/test/test/pch.h +++ b/test/test/pch.h @@ -10,3 +10,39 @@ #include "catch.hpp" using namespace std::literals; + +// Extracts return and progress types from IAsyncXxx. + +template +struct async_traits; + +template<> +struct async_traits +{ + using progress_type = void; +}; + +template +struct async_traits> +{ + using progress_type = P; +}; + +template +struct async_traits> +{ + using progress_type = void; +}; + +template +struct async_traits> +{ + using progress_type = P; +}; + +template +using async_return_type = decltype(std::declval().GetResults()); +template +using async_progress_type = typename async_traits>::progress_type; +template +inline constexpr bool has_async_progress = !std::is_same_v>::progress_type>; diff --git a/test/test_win7/async_result.cpp b/test/test_win7/async_result.cpp index f02c10ce4..9c1accccb 100644 --- a/test/test_win7/async_result.cpp +++ b/test/test_win7/async_result.cpp @@ -40,7 +40,15 @@ namespace handle completed{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + if constexpr (has_async_progress) + { + // You're allowed to peek at partial results of IAsyncXxxWithProgress. + REQUIRE_NOTHROW(async.GetResults()); + } + else + { + REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + } async.Completed([&](auto&& sender, AsyncStatus status) { diff --git a/test/test_win7/async_suspend.cpp b/test/test_win7/async_suspend.cpp index 3c0fd791e..f7beee21d 100644 --- a/test/test_win7/async_suspend.cpp +++ b/test/test_win7/async_suspend.cpp @@ -49,7 +49,15 @@ namespace handle completed{ CreateEvent(nullptr, true, false, nullptr) }; auto async = make(start.get()); REQUIRE(async.Status() == AsyncStatus::Started); - REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + if constexpr (has_async_progress) + { + // You're allowed to peek at partial results of IAsyncXxxWithProgress. + REQUIRE_NOTHROW(async.GetResults()); + } + else + { + REQUIRE_THROWS_AS(async.GetResults(), hresult_illegal_method_call); + } async.Completed([&](auto&& sender, AsyncStatus status) { diff --git a/test/test_win7/pch.h b/test/test_win7/pch.h index 2f7ee3989..84bb40137 100644 --- a/test/test_win7/pch.h +++ b/test/test_win7/pch.h @@ -7,3 +7,39 @@ #include "catch.hpp" using namespace std::literals; + +// Extracts return and progress types from IAsyncXxx. + +template +struct async_traits; + +template<> +struct async_traits +{ + using progress_type = void; +}; + +template +struct async_traits> +{ + using progress_type = P; +}; + +template +struct async_traits> +{ + using progress_type = void; +}; + +template +struct async_traits> +{ + using progress_type = P; +}; + +template +using async_return_type = decltype(std::declval().GetResults()); +template +using async_progress_type = typename async_traits>::progress_type; +template +inline constexpr bool has_async_progress = !std::is_same_v>::progress_type>;