diff --git a/strings/base_implements.h b/strings/base_implements.h index df51f052f..e2c6f3a12 100644 --- a/strings/base_implements.h +++ b/strings/base_implements.h @@ -1063,7 +1063,7 @@ namespace winrt::impl using is_agile = std::negation...>>; using is_inspectable = std::disjunction...>; - using is_weak_ref_source = std::conjunction...>>>; + using is_weak_ref_source = std::negation...>>; using use_module_lock = std::negation...>>; using weak_ref_t = impl::weak_ref; @@ -1125,57 +1125,64 @@ namespace winrt::impl impl::IWeakReferenceSource* make_weak_ref() noexcept { - static_assert(is_weak_ref_source::value, "This is only for weak ref support."); - uintptr_t count_or_pointer = m_references.load(std::memory_order_relaxed); - - if (is_weak_ref(count_or_pointer)) + if constexpr (is_weak_ref_source::value) { - return decode_weak_ref(count_or_pointer)->get_source(); - } - - com_ptr weak_ref; - *weak_ref.put() = new (std::nothrow) weak_ref_t(get_unknown(), static_cast(count_or_pointer)); + uintptr_t count_or_pointer = m_references.load(std::memory_order_relaxed); - if (!weak_ref) - { - return nullptr; - } + if (is_weak_ref(count_or_pointer)) + { + return decode_weak_ref(count_or_pointer)->get_source(); + } - uintptr_t const encoding = encode_weak_ref(weak_ref.get()); + com_ptr weak_ref; + *weak_ref.put() = new (std::nothrow) weak_ref_t(get_unknown(), static_cast(count_or_pointer)); - for (;;) - { - if (m_references.compare_exchange_weak(count_or_pointer, encoding, std::memory_order_acq_rel, std::memory_order_relaxed)) + if (!weak_ref) { - impl::IWeakReferenceSource* result = weak_ref->get_source(); - detach_abi(weak_ref); - return result; + return nullptr; } - if (is_weak_ref(count_or_pointer)) + uintptr_t const encoding = encode_weak_ref(weak_ref.get()); + + for (;;) { - return decode_weak_ref(count_or_pointer)->get_source(); - } + if (m_references.compare_exchange_weak(count_or_pointer, encoding, std::memory_order_acq_rel, std::memory_order_relaxed)) + { + impl::IWeakReferenceSource* result = weak_ref->get_source(); + detach_abi(weak_ref); + return result; + } + + if (is_weak_ref(count_or_pointer)) + { + return decode_weak_ref(count_or_pointer)->get_source(); + } - weak_ref->set_strong(static_cast(count_or_pointer)); + weak_ref->set_strong(static_cast(count_or_pointer)); + } + } + else + { + static_assert(is_weak_ref_source::value, "Weak references are not supported because no_weak_ref was specified."); + return nullptr; } } static bool is_weak_ref(intptr_t const value) noexcept { - static_assert(is_weak_ref_source::value, "This is only for weak ref support."); + static_assert(is_weak_ref_source::value, "Weak references are not supported because no_weak_ref was specified."); return value < 0; } static weak_ref_t* decode_weak_ref(uintptr_t const value) noexcept { - static_assert(is_weak_ref_source::value, "This is only for weak ref support."); + static_assert(is_weak_ref_source::value, "Weak references are not supported because no_weak_ref was specified."); return reinterpret_cast(value << 1); } static uintptr_t encode_weak_ref(weak_ref_t* value) noexcept { - static_assert(is_weak_ref_source::value, "This is only for weak ref support."); + static_assert(is_weak_ref_source::value, "Weak references are not supported because no_weak_ref was specified."); constexpr uintptr_t pointer_flag = static_cast(1) << ((sizeof(uintptr_t) * 8) - 1); WINRT_ASSERT((reinterpret_cast(value) & 1) == 0); return (reinterpret_cast(value) >> 1) | pointer_flag; diff --git a/test/old_tests/UnitTests/weak.cpp b/test/old_tests/UnitTests/weak.cpp index ba0fb5d86..d93377df5 100644 --- a/test/old_tests/UnitTests/weak.cpp +++ b/test/old_tests/UnitTests/weak.cpp @@ -22,7 +22,7 @@ namespace } }; - struct NoWeak : implements + struct WeakClassicCom : implements { }; @@ -161,6 +161,26 @@ TEST_CASE("weak,source") REQUIRE(b.ToString() == L"Weak"); } + SECTION("classic-com") + { + com_ptr<::IUnknown> a = make(); + + weak_ref<::IUnknown> w = a; + com_ptr<::IUnknown> b = w.get(); + REQUIRE(b == a); + + // still one outstanding reference + b = nullptr; + b = w.get(); + REQUIRE(b != nullptr); + + // no outstanding references + a = nullptr; + b = nullptr; + b = w.get(); + REQUIRE(b == nullptr); + } + // Verify that deduction guides work. static_assert(std::is_same_v, decltype(weak_ref(IStringable()))>); static_assert(std::is_same_v, decltype(weak_ref(std::declval()))>); @@ -206,12 +226,24 @@ TEST_CASE("weak,QI") REQUIRE(ref.as<::IUnknown>() != object.as<::IUnknown>()); } - SECTION("no-weak") + SECTION("weak-classic-com") { - com_ptr<::IUnknown> object = make(); + com_ptr<::IUnknown> object = make(); REQUIRE(!object.try_as()); - REQUIRE(!object.try_as()); + REQUIRE(object.try_as()); REQUIRE(!object.try_as()); + + com_ptr source = object.as(); + REQUIRE(!source.try_as()); + REQUIRE(source.try_as()); + REQUIRE(object.as<::IUnknown>() == source.as<::IUnknown>()); + + com_ptr ref; + REQUIRE(S_OK == source->GetWeakReference(ref.put())); + REQUIRE(!ref.try_as()); + REQUIRE(!ref.try_as()); + REQUIRE(ref.as() == ref); + REQUIRE(ref.as<::IUnknown>() != object.as<::IUnknown>()); } SECTION("factory")