diff --git a/strings/base_weak_ref.h b/strings/base_weak_ref.h index 24fe8fb91..c98eb93a7 100644 --- a/strings/base_weak_ref.h +++ b/strings/base_weak_ref.h @@ -6,28 +6,17 @@ WINRT_EXPORT namespace winrt { weak_ref(std::nullptr_t = nullptr) noexcept {} - weak_ref(impl::com_ref const& object) + template const&, typename = std::enable_if_t const&>>> + weak_ref(U&& object) { - if (object) - { - if constexpr(impl::is_implements_v) - { - m_ref = std::move(object->get_weak().m_ref); - } - else - { - // An access violation (crash) on the following line means that the object does not support weak references. - // Avoid using weak_ref/auto_revoke with such objects. - check_hresult(object.template try_as()->GetWeakReference(m_ref.put())); - } - } + from_com_ref(static_cast const&>(object)); } - [[nodiscard]] impl::com_ref get() const noexcept + [[nodiscard]] auto get() const noexcept { if (!m_ref) { - return nullptr; + return impl::com_ref{ nullptr }; } if constexpr(impl::is_implements_v) @@ -36,13 +25,13 @@ WINRT_EXPORT namespace winrt m_ref->Resolve(guid_of(), put_abi(temp)); void* result = get_self(temp); detach_abi(temp); - return { result, take_ownership_from_abi }; + return impl::com_ref{ result, take_ownership_from_abi }; } else { void* result{}; m_ref->Resolve(guid_of(), &result); - return { result, take_ownership_from_abi }; + return impl::com_ref{ result, take_ownership_from_abi }; } } @@ -58,6 +47,24 @@ WINRT_EXPORT namespace winrt private: + template + void from_com_ref(U&& object) + { + if (object) + { + if constexpr (impl::is_implements_v) + { + m_ref = std::move(object->get_weak().m_ref); + } + else + { + // An access violation (crash) on the following line means that the object does not support weak references. + // Avoid using weak_ref/auto_revoke with such objects. + check_hresult(object.template try_as()->GetWeakReference(m_ref.put())); + } + } + } + com_ptr m_ref; }; diff --git a/test/old_tests/UnitTests/weak.cpp b/test/old_tests/UnitTests/weak.cpp index 650b6ae0f..b253e6010 100644 --- a/test/old_tests/UnitTests/weak.cpp +++ b/test/old_tests/UnitTests/weak.cpp @@ -49,6 +49,25 @@ namespace return L"WeakNoModuleLock"; } }; + + struct WeakWithSelfReference : implements + { + winrt::weak_ref weak_self = get_weak(); + + hstring ToString() + { + // Verify that the weak reference works as long as the object is alive. + REQUIRE(weak_self.get().get() == this); + + return L"WeakWithSelfReference"; + } + + ~WeakWithSelfReference() + { + // Verify that the weak reference cannot be resolved once destruction begins. + REQUIRE(weak_self.get() == nullptr); + } + }; } TEST_CASE("weak,source") @@ -313,6 +332,49 @@ TEST_CASE("weak,comparison") REQUIRE(refA1 != refNothing); } +TEST_CASE("weak,assignment") +{ + IStringable object = make(); + weak_ref ref1 = object; + + // Move constructor + weak_ref ref2 = std::move(ref1); + REQUIRE(ref1 == nullptr); + + // Copy constructor + weak_ref ref3 = ref2; + REQUIRE(ref2 == ref3); + + // Copy assignment + ref1 = ref2; + REQUIRE(ref1 == ref2); + REQUIRE(ref1 == ref3); + + // Move assignment + ref1 = std::move(ref2); + REQUIRE(ref2 == nullptr); + REQUIRE(ref1 == ref3); + + // Copy assignment from const + ref2 = static_cast const&>(ref1); + REQUIRE(ref1 == ref2); + + // Move assignment from const + ref1 = static_cast const&&>(ref2); + REQUIRE(ref1 == ref2); + + // Constructed from com_ref braced constructor + weak_ref yikes{ { nullptr, take_ownership_from_abi } }; + + // Not constructible from L"" (because Uri constructor is explicit) + static_assert(!std::is_constructible_v, const wchar_t*>); + + // Constructible from com_ptr because com_ptr is + // implicitly convertible to com_ptr. + struct Derived : WeakWithSelfReference {}; + weak_ref decay{ winrt::com_ptr{nullptr} }; +} + TEST_CASE("weak,module_lock") { uint32_t object_count = get_module_lock(); @@ -344,3 +406,10 @@ TEST_CASE("weak,no_module_lock") REQUIRE(get_module_lock() == object_count); } +TEST_CASE("weak,self") +{ + // The REQUIRE statements are in the WeakWithSelfReference class itself. + IStringable a = make(); + a.ToString(); + a = nullptr; +}