Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions strings/base_implements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ namespace winrt::impl

using is_agile = std::negation<std::disjunction<std::is_same<non_agile, I>...>>;
using is_inspectable = std::disjunction<std::is_base_of<Windows::Foundation::IInspectable, I>...>;
using is_weak_ref_source = std::conjunction<is_inspectable, std::negation<std::disjunction<std::is_same<no_weak_ref, I>...>>>;
using is_weak_ref_source = std::negation<std::disjunction<std::is_same<no_weak_ref, I>...>>;
using use_module_lock = std::negation<std::disjunction<std::is_same<no_module_lock, I>...>>;
using weak_ref_t = impl::weak_ref<is_agile::value, use_module_lock::value>;

Expand Down Expand Up @@ -1125,57 +1125,64 @@ namespace winrt::impl

impl::IWeakReferenceSource* make_weak_ref() noexcept
{
static_assert(is_weak_ref_source::value, "This is only for weak ref support.");
uintptr_t count_or_pointer = m_references.load(std::memory_order_relaxed);

if (is_weak_ref(count_or_pointer))
if constexpr (is_weak_ref_source::value)
{
return decode_weak_ref(count_or_pointer)->get_source();
}

com_ptr<weak_ref_t> weak_ref;
*weak_ref.put() = new (std::nothrow) weak_ref_t(get_unknown(), static_cast<uint32_t>(count_or_pointer));
uintptr_t count_or_pointer = m_references.load(std::memory_order_relaxed);

if (!weak_ref)
{
return nullptr;
}
if (is_weak_ref(count_or_pointer))
{
return decode_weak_ref(count_or_pointer)->get_source();
}

uintptr_t const encoding = encode_weak_ref(weak_ref.get());
com_ptr<weak_ref_t> weak_ref;
*weak_ref.put() = new (std::nothrow) weak_ref_t(get_unknown(), static_cast<uint32_t>(count_or_pointer));

for (;;)
{
if (m_references.compare_exchange_weak(count_or_pointer, encoding, std::memory_order_acq_rel, std::memory_order_relaxed))
if (!weak_ref)
{
impl::IWeakReferenceSource* result = weak_ref->get_source();
detach_abi(weak_ref);
return result;
return nullptr;
}

if (is_weak_ref(count_or_pointer))
uintptr_t const encoding = encode_weak_ref(weak_ref.get());

for (;;)
{
return decode_weak_ref(count_or_pointer)->get_source();
}
if (m_references.compare_exchange_weak(count_or_pointer, encoding, std::memory_order_acq_rel, std::memory_order_relaxed))
{
impl::IWeakReferenceSource* result = weak_ref->get_source();
detach_abi(weak_ref);
return result;
}

if (is_weak_ref(count_or_pointer))
{
return decode_weak_ref(count_or_pointer)->get_source();
}

weak_ref->set_strong(static_cast<uint32_t>(count_or_pointer));
weak_ref->set_strong(static_cast<uint32_t>(count_or_pointer));
}
}
else
{
static_assert(is_weak_ref_source::value, "Weak references are not supported because no_weak_ref was specified.");
return nullptr;
}
}

static bool is_weak_ref(intptr_t const value) noexcept
{
static_assert(is_weak_ref_source::value, "This is only for weak ref support.");
static_assert(is_weak_ref_source::value, "Weak references are not supported because no_weak_ref was specified.");
return value < 0;
}

static weak_ref_t* decode_weak_ref(uintptr_t const value) noexcept
{
static_assert(is_weak_ref_source::value, "This is only for weak ref support.");
static_assert(is_weak_ref_source::value, "Weak references are not supported because no_weak_ref was specified.");
return reinterpret_cast<weak_ref_t*>(value << 1);
}

static uintptr_t encode_weak_ref(weak_ref_t* value) noexcept
{
static_assert(is_weak_ref_source::value, "This is only for weak ref support.");
static_assert(is_weak_ref_source::value, "Weak references are not supported because no_weak_ref was specified.");
constexpr uintptr_t pointer_flag = static_cast<uintptr_t>(1) << ((sizeof(uintptr_t) * 8) - 1);
WINRT_ASSERT((reinterpret_cast<uintptr_t>(value) & 1) == 0);
return (reinterpret_cast<uintptr_t>(value) >> 1) | pointer_flag;
Expand Down
40 changes: 36 additions & 4 deletions test/old_tests/UnitTests/weak.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace
}
};

struct NoWeak : implements<NoWeak, ::IUnknown>
struct WeakClassicCom : implements<WeakClassicCom, ::IUnknown>
{
};

Expand Down Expand Up @@ -161,6 +161,26 @@ TEST_CASE("weak,source")
REQUIRE(b.ToString() == L"Weak");
}

SECTION("classic-com")
{
com_ptr<::IUnknown> a = make<WeakClassicCom>();

weak_ref<::IUnknown> w = a;
com_ptr<::IUnknown> b = w.get();
REQUIRE(b == a);

// still one outstanding reference
b = nullptr;
b = w.get();
REQUIRE(b != nullptr);

// no outstanding references
a = nullptr;
b = nullptr;
b = w.get();
REQUIRE(b == nullptr);
}

// Verify that deduction guides work.
static_assert(std::is_same_v<weak_ref<IStringable>, decltype(weak_ref(IStringable()))>);
static_assert(std::is_same_v<weak_ref<Uri>, decltype(weak_ref(std::declval<Uri>()))>);
Expand Down Expand Up @@ -206,12 +226,24 @@ TEST_CASE("weak,QI")
REQUIRE(ref.as<::IUnknown>() != object.as<::IUnknown>());
}

SECTION("no-weak")
SECTION("weak-classic-com")
{
com_ptr<::IUnknown> object = make<NoWeak>();
com_ptr<::IUnknown> object = make<WeakClassicCom>();
REQUIRE(!object.try_as<Windows::Foundation::IInspectable>());
REQUIRE(!object.try_as<winrt::impl::IWeakReferenceSource>());
REQUIRE(object.try_as<winrt::impl::IWeakReferenceSource>());
REQUIRE(!object.try_as<winrt::impl::IWeakReference>());

com_ptr<winrt::impl::IWeakReferenceSource> source = object.as<winrt::impl::IWeakReferenceSource>();
REQUIRE(!source.try_as<winrt::impl::IWeakReference>());
REQUIRE(source.try_as<winrt::impl::IWeakReferenceSource>());
REQUIRE(object.as<::IUnknown>() == source.as<::IUnknown>());

com_ptr<winrt::impl::IWeakReference> ref;
REQUIRE(S_OK == source->GetWeakReference(ref.put()));
REQUIRE(!ref.try_as<winrt::impl::IWeakReferenceSource>());
REQUIRE(!ref.try_as<Windows::Foundation::IInspectable>());
REQUIRE(ref.as<winrt::impl::IWeakReference>() == ref);
REQUIRE(ref.as<::IUnknown>() != object.as<::IUnknown>());
}

SECTION("factory")
Expand Down