diff --git a/test/test/multi_threaded_common.h b/test/test/multi_threaded_common.h index 13ebe56a3..f8013f290 100644 --- a/test/test/multi_threaded_common.h +++ b/test/test/multi_threaded_common.h @@ -1,73 +1,323 @@ #pragma once -#include - -struct unique_thread +namespace concurrent_collections { - std::thread thread; - std::exception_ptr ex; + template // int or IInspectable + T conditional_box(int value) + { + if constexpr (std::is_same_v) + { + return value; + } + else + { + return winrt::box_value(value); + } + } + + template + int conditional_unbox(T const& value) + { + if constexpr (std::is_same_v) + { + return value; + } + else + { + return winrt::unbox_value(value); + } + } - unique_thread() = default; + // When debugging, you may want to increase this so you can set breakpoints + // without triggering the timeouts. + // + // This determines how long we wait before we decide that our intentionally-frozen + // threads have triggered a deadlock, which is expected when using multithread-safe collections: + // Multithread-safe collections will wait for one thread to exit the collection before + // allowing the next one to enter, and freezing inside the collection will cause a hang. + // The reason for freezing inside the collection is to confirm that the other thread will + // indeed wait for the first thread to finish before proceeding. + static inline constexpr DWORD DEADLOCK_TIMEOUT = 10; - template - unique_thread(Func&& fn, Args&&... args) +#pragma region collection hooks + + // The collection hook injects a delay when a particular action occurs + // for the first time on the background thread. + enum class collection_action { - thread = std::thread([this, fn = std::forward(fn)](auto&&... args) + none, push_back, insert, erase, at, lookup + }; + + // All of our concurrency tests consists of starting an + // operation on the background thread and then while that operation is + // in progress, performing some other operation on the main thread and + // verifying that nothing bad happens. + // + // | Background thread | Main thread | + // |----------------------|-------------------| + // | something() | | Step 1 + // | | vector.something() | | + // | | | | | <--- pause background thread, start main thread + // | | | | foreground() | Step 2 + // | | | | | <--- resume background thread + // | | | do_the_thing | | Step 3 + + struct collection_hook + { + collection_action race_action = collection_action::none; + int step = 0; + DWORD mainThreadId = GetCurrentThreadId(); + + collection_hook() = default; + + void on_action(collection_action action) { - try + if ((action == race_action) && (GetCurrentThreadId() != mainThreadId)) { - fn(std::forward(args)...); + race_action = collection_action::none; + GoToStep(2); + WaitForStep(3); } - catch (...) + } + + template + void race(collection_action action, Background&& background, Foreground&& foreground) + { + race_action = action; + step = 1; + + auto task = [](auto&& background) -> winrt::Windows::Foundation::IAsyncAction + { + co_await winrt::resume_background(); + background(); + }(background); + + WaitForStep(2); + foreground(); + GoToStep(3); + + // Wait for background task to complete. + task.get(); + race_action = collection_action::none; + } + + private: + // The hooks exist so we can proceed through a sequence of + // steps in order to force race conditions. These helper function + // control the progress through those steps. + + void GoToStep(int value) + { + if (step < value) { - ex = std::current_exception(); + step = value; + WakeByAddressAll(&step); } - }, std::forward(args)...); - } + } - ~unique_thread() noexcept(false) - { - if (thread.joinable()) + bool WaitForStep(int value, DWORD timeout = DEADLOCK_TIMEOUT) { - join(); + int current; + while ((current = step) < value) + { + if (!WaitOnAddress(&step, ¤t, sizeof(current), timeout)) + { + return false; // timed out + } + } + return true; } - } + }; - unique_thread(unique_thread&&) = default; - unique_thread& operator=(unique_thread&&) = default; +#pragma endregion - void join() +#pragma region iterator wrapper + template + struct concurrency_checked_random_access_iterator : Iterator { - thread.join(); - if (ex) + using container = Container; + using iterator = Iterator; + + using size_type = typename container::size_type; + + using difference_type = typename iterator::difference_type; + using value_type = typename iterator::value_type; + using pointer = typename iterator::pointer; + using reference = typename iterator::reference; + using iterator_category = typename iterator::iterator_category; + + container const* owner; + + concurrency_checked_random_access_iterator() : owner(nullptr) {} + concurrency_checked_random_access_iterator(container const* c, iterator it) : owner(c), iterator(it) {} + + // Implicit conversion from non-const iterator to const iterator. + template>> + concurrency_checked_random_access_iterator(concurrency_checked_random_access_iterator other) : owner(other.owner), iterator(other.inner()) { } + + concurrency_checked_random_access_iterator(concurrency_checked_random_access_iterator const&) = default; + concurrency_checked_random_access_iterator& operator=(concurrency_checked_random_access_iterator const&) = default; + + iterator& inner() { return static_cast(*this); } + iterator const& inner() const { return static_cast(*this); } + + reference operator*() const { - std::rethrow_exception(ex); + return owner->dereference_iterator(inner()); } - } -}; -template // int or IInspectable -T conditional_box(int value) -{ - if constexpr (std::is_same_v) - { - return value; - } - else + // inherited: pointer operator->() const; + + concurrency_checked_random_access_iterator& operator++() + { + ++inner(); + return *this; + } + + concurrency_checked_random_access_iterator& operator++(int) + { + auto prev = *this; + ++inner(); + return prev; + } + + concurrency_checked_random_access_iterator& operator--() + { + --inner(); + return *this; + } + + concurrency_checked_random_access_iterator& operator--(int) + { + auto prev = *this; + --inner(); + return prev; + } + + concurrency_checked_random_access_iterator& operator+=(difference_type offset) + { + inner() += offset; + return *this; + } + + concurrency_checked_random_access_iterator operator+(difference_type pos) const + { + return { owner, inner() + pos }; + } + + concurrency_checked_random_access_iterator& operator-=(difference_type offset) + { + inner() -= offset; + return *this; + } + + concurrency_checked_random_access_iterator operator-(difference_type pos) const + { + return { owner, inner() - pos }; + } + + difference_type operator-(concurrency_checked_random_access_iterator const& other) const + { + return inner() - other.inner(); + } + + reference operator[](size_type pos) const + { + return owner->dereference_iterator(inner() + pos); + } + + // inherited: all comparison operators + }; + + // "integer + iterator" must be defined as a free operator. + template + concurrency_checked_random_access_iterator operator+( + typename concurrency_checked_random_access_iterator::difference_type offset, + concurrency_checked_random_access_iterator it) { - return winrt::box_value(value); + return it += offset; } -} +#pragma endregion -template -int conditional_unbox(T const& value) -{ - if constexpr (std::is_same_v) + struct concurrency_guard { - return value; - } - else + // Clients can use the hook to alter behavior. + std::shared_ptr hook = std::make_shared(); + + concurrency_guard() = default; + concurrency_guard(concurrency_guard const& other) noexcept + : m_lock(0), hook(other.hook) + { + auto guard = other.lock_nonconst(); + } + + void call_hook(collection_action action) const + { + return hook->on_action(action); + } + + struct const_access_guard + { + concurrency_guard const* owner; + + const_access_guard(concurrency_guard const* v) : owner(v) + { + CHECK(++owner->m_lock > 0); + } + + ~const_access_guard() + { + --owner->m_lock; + } + }; + + struct nonconst_access_guard + { + concurrency_guard const* owner; + + nonconst_access_guard(concurrency_guard const* v) : owner(v) + { + CHECK(--owner->m_lock == -1); + } + + ~nonconst_access_guard() + { + owner->m_lock = 0; + } + }; + + const_access_guard lock_const() const + { + return { this }; + } + + nonconst_access_guard lock_nonconst() const + { + return { this }; + } + + private: + // 0 = not being accessed + // -1 = a thread is inside a non-const method + // positive = number of threads inside a const method + + std::atomic mutable m_lock; + }; + + template + struct deadlock_object : winrt::implements, winrt::Windows::Foundation::IInspectable> { - return winrt::unbox_value(value); - } + Collection collection; + + deadlock_object(Collection c) : collection(c) {} + + static void final_release(std::unique_ptr self) + { + // Make sure this doesn't deadlock. There are cases where an object's destructor + // triggers a cascade of destruction, and some of the cascade destructors try + // to talk to the same collection that the original object was removed from. + self->collection.Clear(); + } + }; + } diff --git a/test/test/multi_threaded_map.cpp b/test/test/multi_threaded_map.cpp index 517083df5..2f81d3789 100644 --- a/test/test/multi_threaded_map.cpp +++ b/test/test/multi_threaded_map.cpp @@ -8,347 +8,330 @@ using namespace winrt; using namespace Windows::Foundation; using namespace Windows::Foundation::Collections; +using namespace concurrent_collections; // Map correctness tests exist elsewhere. These tests are strictly geared toward testing multi threaded functionality -// -// Be careful with use of REQUIRE. -// -// 1. REQUIRE is not concurrency-safe. Don't call it from two threads simultaneously. -// 2. The number of calls to REQUIRE should be the consistent in the face of nondeterminism. -// This makes the "(x assertions in y test cases)" consistent. -// -// If you need to check something from a background thread, or where the number -// of iterations can vary from run to run, use winrt::check_bool, which still -// fails the test but doesn't run afoul of REQUIRE's limitations. - -template -static void test_single_reader_single_writer(IMap const& map) -{ - static constexpr int final_size = 10000; - // Insert / HasKey / Lookup - unique_thread t([&] +namespace +{ + // We use a customized container that mimics std::map and which + // validates that C++ concurrency rules are observed. + // C++ rules for library types are that concurrent use of const methods is allowed, + // but no method call may be concurrent with a non-const method. (Const methods may + // be "shared", but non-const methods are "exclusive".) + // + // NOTE! As the C++/WinRT implementation changes, you may need to add additional members + // to our mock. + // + // The regular single_threaded_map and multi_threaded_map functions require std::map + // or std::unordered_map, so we bypass them and go directly to the underlying classes, + // which take any container that acts map-like. + + enum class MapKind { - for (int i = 0; i < final_size; ++i) + IMap, + IObservableMap, + }; + + // Change the next line to "#if 0" to use a single-threaded map and confirm that every test fails. + // The scenarios use "CHECK" instead of "REQUIRE" so that they continue running even on failure. + // That way, you can just step through the entire test in single-threaded mode and confirm that + // something bad happens at each scenario. +#if 1 + template + using custom_threaded_map = winrt::impl::multi_threaded_map; + + template + using custom_observable_map = winrt::impl::multi_threaded_observable_map; +#else + template + using custom_threaded_map = winrt::impl::input_map; + + template + using custom_observable_map = winrt::impl::observable_map; +#endif + + template + auto make_threaded_map(Container&& values) + { + using K = typename Container::key_type; + using V = typename Container::mapped_type; + + if constexpr (kind == MapKind::IMap) { - map.Insert(i, conditional_box(i)); - std::this_thread::yield(); + return static_cast>(winrt::make>(std::move(values))); } - }); + else + { + return static_cast>(winrt::make>(std::move(values))); + } + } - while (true) +#pragma region map wrapper + // Add more wrapper methods as necessary. + // (Turns out we don't use many features of std::map and std::unordered_map.) + template, typename Allocator = std::allocator>> + struct concurrency_checked_map : private std::map, concurrency_guard { - int i = 0; - auto beginSize = map.Size(); - for (; i < final_size; ++i) + using inner = typename concurrency_checked_map::map; + using key_type = typename inner::key_type; + using mapped_type = typename inner::mapped_type; + using value_type = typename inner::value_type; + using size_type = typename inner::size_type; + using difference_type = typename inner::difference_type; + using allocator_type = typename inner::allocator_type; + using reference = typename inner::reference; + using const_reference = typename inner::const_reference; + using pointer = typename inner::pointer; + using const_pointer = typename inner::const_pointer; + using iterator = concurrency_checked_random_access_iterator; + using const_iterator = concurrency_checked_random_access_iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + using node_type = typename inner::node_type; + + mapped_type& operator[](const key_type& key) { - if (!map.HasKey(i)) - { - check_bool(static_cast(i) >= beginSize); - break; - } + auto guard = concurrency_guard::lock_nonconst(); + concurrency_guard::call_hook(collection_action::at); + return { this, inner::begin() }; + } - check_bool(conditional_unbox(map.Lookup(i)) == i); + iterator begin() + { + auto guard = concurrency_guard::lock_nonconst(); + return { this, inner::begin() }; } - if (i == final_size) + const_iterator begin() const { - break; + auto guard = concurrency_guard::lock_const(); + return { this, inner::begin() }; } - } -} -template -static void test_iterator_invalidation(IMap const& map) -{ - static constexpr int size = 10; + iterator end() + { + auto guard = concurrency_guard::lock_nonconst(); + return { this, inner::end() }; + } - // Remove / Insert / First / HasCurrent / MoveNext / Current - for (int i = 0; i < size; ++i) - { - map.Insert(i, conditional_box(i)); - } + const_iterator end() const + { + auto guard = concurrency_guard::lock_const(); + return { this, inner::end() }; + } - volatile bool done = false; - unique_thread t([&] - { - // Since the underlying storage is std::map, it's actually quite hard to hit UB that has an observable side - // effect, making it hard to have a meaningful test. The idea here is to remove and re-insert the "first" - // element in a tight loop so that enumeration is likely to hit a concurrent access that's actually meaningful. - // Even then, failures really only occur with a single threaded collection when building Debug - while (!done) + size_type size() const { - map.Remove(0); - map.Insert(0, conditional_box(0)); + auto guard = concurrency_guard::lock_const(); + return inner::size(); } - }); - int exceptionCount = 0; + void clear() + { + auto guard = concurrency_guard::lock_nonconst(); + return inner::clear(); + } - for (int i = 0; i < 10000; ++i) - { - try + void swap(concurrency_checked_map& other) { - int count = 0; - for (auto itr = map.First(); itr.HasCurrent(); itr.MoveNext()) - { - auto pair = itr.Current(); - check_bool(pair.Key() == conditional_unbox(pair.Value())); - ++count; - } - check_bool(count >= (size - 1)); - check_bool(count <= size); + auto guard = concurrency_guard::lock_nonconst(); + inner::swap(other); } - catch (hresult_changed_state const&) + + template + std::pair emplace(Args&&... args) { - ++exceptionCount; + auto guard = concurrency_guard::lock_nonconst(); + concurrency_guard::call_hook(collection_action::insert); + auto [it, inserted] = inner::emplace(std::forward(args)...); + return { { this, it }, inserted }; } - } - done = true; - // In reality, this number should be quite large; much larger than the 50 validated here - REQUIRE(exceptionCount >= 50); -} + node_type extract(const_iterator pos) + { + auto guard = concurrency_guard::lock_nonconst(); + concurrency_guard::call_hook(collection_action::erase); + return inner::extract(pos); + } -template -static void test_concurrent_iteration(IMap const& map) -{ - static constexpr int size = 10000; + const_iterator find(const K& key) const + { + auto guard = concurrency_guard::lock_const(); + concurrency_guard::call_hook(collection_action::lookup); + return { this, inner::find(key) }; + } + }; +#pragma endregion - // Current / HasCurrent + template + void test_map_concurrency() { - for (int i = 0; i < size; ++i) + auto raw = concurrency_checked_map(); + auto hook = raw.hook; + + // Convert the raw_map into the desired Windows Runtime map interface. + auto m = make_threaded_map(std::move(raw)); + + auto race = [&](collection_action action, auto&& background, auto&& foreground) { - map.Insert(i, conditional_box(i)); - } + // Map initial contents are [1] = 1, [2] = 2. + m.Clear(); + m.Insert(1, conditional_box(1)); + m.Insert(2, conditional_box(2)); + + hook->race(action, background, foreground); + }; + + // Verify that Insert does not run concurrently with HasKey(). + race(collection_action::insert, [&] + { + m.Insert(42, conditional_box(42)); + }, [&] + { + CHECK(m.HasKey(2)); + }); + + // Verify that Insert does not run concurrently with Lookup(). + race(collection_action::insert, [&] + { + m.Insert(42, conditional_box(42)); + }, [&] + { + CHECK(conditional_unbox(m.Lookup(2))); + }); + + // Verify that Insert does not run concurrently with another Insert(). + race(collection_action::insert, [&] + { + m.Insert(43, conditional_box(43)); + }, [&] + { + m.Insert(43, conditional_box(43)); + }); - auto itr = map.First(); - unique_thread threads[2]; - int increments[std::size(threads)] = {}; - for (int i = 0; i < std::size(threads); ++i) + // Iterator invalidation tests are a little different because we perform + // the mutation from the foreground thread after the read operation + // has begun on the background thread. + // + // Verify that iterator invalidation doesn't race against + // iterator use. { - threads[i] = unique_thread([&itr, &increments, i] + // Current vs Remove + IKeyValuePair kvp; + race(collection_action::at, [&] { - int last = -1; - while (true) + try + { + kvp = m.First().Current(); + } + catch (hresult_error const&) { - try - { - // NOTE: Because there is no atomic "get and increment" function on IIterator, the best we can do is - // validate that we're getting valid increasing values, e.g. as opposed to validating that we read - // all unique values. - auto val = itr.Current().Key(); - check_bool(val > last); - check_bool(val < size); - last = val; - if (!itr.MoveNext()) - { - break; - } - - // MoveNext is the only synchronized operation that has a predictable side effect we can validate - ++increments[i]; - } - catch (hresult_error const&) - { - // There's no "get if" function, so concurrent increment past the end is always possible... - check_bool(!itr.HasCurrent()); - break; - } } + }, [&] + { + m.Remove(1); }); + CHECK((kvp && conditional_unbox(kvp.Value()) == 1)); } - - for (auto& t : threads) { - t.join(); + // MoveNext vs Remove + bool moved = false; + race(collection_action::at, [&] + { + try + { + moved = m.First().MoveNext(); + } + catch (hresult_error const&) + { + } + }, [&] + { + m.Remove(1); + }); + CHECK(moved); } - - REQUIRE(!itr.HasCurrent()); - - auto totalIncrements = std::accumulate(std::begin(increments), std::end(increments), 0); - REQUIRE(totalIncrements == (size - 1)); - } - - // HasCurrent / GetMany - { - auto itr = map.First(); - unique_thread threads[2]; - int totals[std::size(threads)] = {}; - for (int i = 0; i < std::size(threads); ++i) { - threads[i] = unique_thread([&itr, &totals, i] + // Current vs Insert + IKeyValuePair kvp; + race(collection_action::at, [&] { - IKeyValuePair vals[10]; - while (itr.HasCurrent()) + try + { + kvp = m.First().Current(); + } + catch (hresult_error const&) { - // Unlike 'Current', 'GetMany' _is_ atomic in regards to read+increment - auto len = itr.GetMany(vals); - totals[i] += std::accumulate(vals, vals + len, 0, [](int curr, auto const& next) { return curr + next.Key(); }); } + }, [&] + { + m.Insert(42, conditional_box(42)); }); + CHECK((kvp && conditional_unbox(kvp.Value()) == 1)); } - - for (auto& t : threads) { - t.join(); + // MoveNext vs Insert + bool moved = false; + race(collection_action::at, [&] + { + try + { + moved = m.First().MoveNext(); + } + catch (hresult_error const&) + { + } + }, [&] + { + m.Insert(42, conditional_box(42)); + }); + CHECK(moved); } - // sum(i = 1 -> N){i} = N * (N + 1) / 2 - auto total = std::accumulate(std::begin(totals), std::end(totals), 0); - REQUIRE(total == (size * (size - 1) / 2)); - } -} - -template -static void test_multi_writer(IMap const& map) -{ - // Large enough that several threads should be executing concurrently - static constexpr uint32_t size = 10000; - static constexpr size_t threadCount = 8; - - // Insert - unique_thread threads[threadCount]; - for (int i = 0; i < threadCount; ++i) - { - threads[i] = unique_thread([&map, i] { - auto off = i * size; - for (int j = 0; j < size; ++j) + // Verify that concurrent iteration works via GetMany(), which is atomic. + // (Current + MoveNext is non-atomic and can result in two threads + // both grabbing the same Current and then moving two steps forward.) + decltype(m.First()) it; + IKeyValuePair kvp1[1]; + IKeyValuePair kvp2[1]; + race(collection_action::at, [&] { - map.Insert(j + off, conditional_box(j)); - } - }); - } - - for (auto& t : threads) - { - t.join(); - } - - REQUIRE(map.Size() == (size * threadCount)); - - // Since we know that the underlying collection type is std::map, the keys should be ordered - int expect = 0; - for (auto&& pair : map) - { - REQUIRE(pair.Key() == expect++); - } -} - -template -struct exclusive_map : - map_base, K, V>, - implements, IMap, IMapView, IIterable>> -{ - std::map container; - mutable std::shared_mutex mutex; - - auto& get_container() noexcept - { - return container; - } - - auto& get_container() const noexcept - { - return container; - } - - // It is not safe to recursively acquire an SRWLOCK, even in shared mode, however this is unchecked by the SRWLOCK - // implementation. Using a vector that only performs exclusive operations is the simplest way to validate that - // the implementation does not attempt to recursively acquire the mutex. - template - auto perform_exclusive(Func&& fn) const - { - // Exceptions are better than deadlocks... - REQUIRE(mutex.try_lock()); - std::lock_guard guard(mutex, std::adopt_lock); - return fn(); - } -}; - -struct map_deadlock_object : implements> -{ - int m_value; - exclusive_map>* m_vector; - - map_deadlock_object(int value, exclusive_map>* vector) : - m_value(value), - m_vector(vector) - { - } - - ~map_deadlock_object() - { - // NOTE: This will crash on failure, but that's better than actually deadlocking - REQUIRE(m_vector->mutex.try_lock()); - m_vector->mutex.unlock(); + it = m.First(); + CHECK(it.GetMany(kvp1) == 1); + }, [&] + { + CHECK(it.GetMany(kvp2) == 1); + }); + CHECK(kvp1[0].Key() != kvp2[0].Key()); + } } - int Value() const noexcept + void deadlock_test() { - return m_value; + auto m = make_threaded_map(concurrency_checked_map()); + m.Insert(0, make>>(m)); + auto task = [](auto m)-> IAsyncAction + { + co_await resume_background(); + m.Remove(0); + }(m); + auto status = task.wait_for(std::chrono::milliseconds(DEADLOCK_TIMEOUT)); + REQUIRE(status == AsyncStatus::Completed); } -}; - -static void deadlock_test() -{ - auto map = make_self>>(); - - map->Insert(0, make(0, map.get())); - map->Insert(1, make(1, map.get())); - REQUIRE(map->Size() == 2); - REQUIRE(map->HasKey(0)); - REQUIRE(!map->HasKey(2)); - REQUIRE(map->Lookup(0).Value() == 0); - map->Remove(0); - REQUIRE(map->Size() == 1); - map->Clear(); - REQUIRE(map->Size() == 0); - - map->Insert(0, make(0, map.get())); - map->Insert(1, make(1, map.get())); - auto view = map->GetView(); - REQUIRE(view.Size() == 2); - REQUIRE(view.HasKey(0)); - REQUIRE(view.Lookup(1).Value() == 1); - - auto itr = map->First(); - REQUIRE(itr.HasCurrent()); - REQUIRE(itr.Current().Key() == 0); - REQUIRE(itr.MoveNext()); - REQUIRE(!itr.MoveNext()); - REQUIRE(!itr.HasCurrent()); } TEST_CASE("multi_threaded_map") { - test_single_reader_single_writer(multi_threaded_map()); - test_single_reader_single_writer(multi_threaded_map()); - - test_iterator_invalidation(multi_threaded_map()); - test_iterator_invalidation(multi_threaded_map()); - - test_concurrent_iteration(multi_threaded_map()); - test_concurrent_iteration(multi_threaded_map()); - - test_multi_writer(multi_threaded_map()); - test_multi_writer(multi_threaded_map()); - + test_map_concurrency(); + test_map_concurrency(); deadlock_test(); } TEST_CASE("multi_threaded_observable_map") { - test_single_reader_single_writer(multi_threaded_observable_map()); - test_single_reader_single_writer(multi_threaded_observable_map()); - - test_iterator_invalidation(multi_threaded_observable_map()); - test_iterator_invalidation(multi_threaded_observable_map()); - - test_concurrent_iteration(multi_threaded_observable_map()); - test_concurrent_iteration(multi_threaded_observable_map()); - test_multi_writer(multi_threaded_observable_map()); - test_multi_writer(multi_threaded_observable_map()); + test_map_concurrency(); + test_map_concurrency(); } diff --git a/test/test/multi_threaded_vector.cpp b/test/test/multi_threaded_vector.cpp index ab415101a..24b27d0bc 100644 --- a/test/test/multi_threaded_vector.cpp +++ b/test/test/multi_threaded_vector.cpp @@ -1,505 +1,468 @@ #include "pch.h" -#include -#include - #include "multi_threaded_common.h" using namespace winrt; using namespace Windows::Foundation; using namespace Windows::Foundation::Collections; +using namespace concurrent_collections; -// Vector correctness tests exist elsewhere. These tests are strictly geared toward testing multi threaded functionality -// -// Be careful with use of REQUIRE. -// -// 1. REQUIRE is not concurrency-safe. Don't call it from two threads simultaneously. -// 2. The number of calls to REQUIRE should be the consistent in the face of nondeterminism. -// This makes the "(x assertions in y test cases)" consistent. -// -// If you need to check something from a background thread, or where the number -// of iterations can vary from run to run, use winrt::check_bool, which still -// fails the test but doesn't run afoul of REQUIRE's limitations. - -template -static void test_single_reader_single_writer(IVector const& v) -{ - static constexpr int final_size = 10000; +// Vector correctness tests exist elsewhere. These tests are strictly geared toward testing multi threaded functionality. - // Append / Size / GetAt / IndexOf +namespace +{ + // We use a customized container that mimics std::vector and which + // validates that C++ concurrency rules are observed. + // C++ rules for library types are that concurrent use of const methods is allowed, + // but no method call may be concurrent with a non-const method. (Const methods may + // be "shared", but non-const methods are "exclusive".) + // + // NOTE! As the C++/WinRT implementation changes, you may need to add additional members + // to our fake vector and vector iterator classes. + // + // The regular single_threaded_vector and multi_threaded_vector functions requires std::vector, + // so we bypass that method and go directly to input_vector, which takes an arbitrary container + // that acts vector-like. + + enum class VectorKind { - unique_thread t([&] + IVector, + IObservableVector, + IObservableVectorAsInspectable, + }; + + // Change the next line to "#if 0" to use a single-threaded vector and confirm that every test fails. + // The scenarios use "CHECK" instead of "REQUIRE" so that they continue running even on failure. + // That way, you can just step through the entire test and confirm that something bad happens + // at each scenario. +#if 1 + template + using custom_threaded_vector = winrt::impl::multi_threaded_vector; + + template + using custom_inspectable_observable_vector = winrt::impl::multi_threaded_inspectable_observable_vector; + + template + using custom_convertible_observable_vector = winrt::impl::multi_threaded_convertible_observable_vector; + +#else + template + using custom_threaded_vector = winrt::impl::input_vector; + + template + using custom_inspectable_observable_vector = winrt::impl::inspectable_observable_vector; + + template + using custom_convertible_observable_vector = winrt::impl::convertible_observable_vector; +#endif + + template + auto make_threaded_vector(Container&& values) + { + using T = typename Container::value_type; + if constexpr (kind == VectorKind::IVector) { - for (int i = 0; i < final_size; ++i) - { - v.Append(conditional_box(i)); - std::this_thread::yield(); - } - }); - - while (true) + return static_cast>(winrt::make>(std::move(values))); + } + else { - auto beginSize = v.Size(); - int i = 0; - for (; i < final_size; ++i) + IObservableVector vector; + if constexpr (std::is_same_v) { - if (static_cast(i) >= v.Size()) - { - check_bool(static_cast(i) >= beginSize); - break; - } - - check_bool(conditional_unbox(v.GetAt(i)) == i); - - if constexpr (std::is_same_v) - { - uint32_t index; - check_bool(v.IndexOf(i, index)); - check_bool(index == static_cast(i)); - } + vector = make>(std::move(values)); } - - if (i == final_size) + else { - break; + vector = make>(std::move(values)); } - check_bool(beginSize != final_size); - } - } - - // InsertAt / Size / GetMany - { - v.Clear(); - unique_thread t([&] - { - for (int i = 0; i < final_size; ++i) + if constexpr (kind == VectorKind::IObservableVector) { - v.InsertAt(0, conditional_box(i)); - std::this_thread::yield(); + return vector; } - }); - - T vals[100]; - while (v.Size() < final_size) - { - auto len = v.GetMany(0, vals); - for (uint32_t i = 1; i < len; ++i) + else { - check_bool(conditional_unbox(vals[i]) == (conditional_unbox(vals[i - 1]) - 1)); + return vector.as>(); } } } - // RemoveAt / Size / GetMany +#pragma region vector wrapper + // Add more wrapper methods as necessary. + template> + struct concurrency_checked_vector : private std::vector, concurrency_guard { - unique_thread t([&] + using inner = typename concurrency_checked_vector::vector; + using value_type = typename inner::value_type; + using allocator_type = typename inner::allocator_type; + using size_type = typename inner::size_type; + using difference_type = typename inner::difference_type; + using reference = typename inner::reference; + using const_reference = typename inner::const_reference; + using iterator = concurrency_checked_random_access_iterator; + using const_iterator = concurrency_checked_random_access_iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + static_assert(!std::is_same_v, "Has never been tested with bool."); + + concurrency_checked_vector() = default; + concurrency_checked_vector(concurrency_checked_vector&& other) = default; + + iterator begin() { - while (v.Size() != 0) - { - v.RemoveAt(0); - std::this_thread::yield(); - } - }); + auto guard = concurrency_guard::lock_nonconst(); + return { this, inner::begin() }; + } - T vals[100]; - while (v.Size() > 0) + const_iterator begin() const { - auto len = v.GetMany(0, vals); - for (uint32_t i = 1; i < len; ++i) - { - check_bool(conditional_unbox(vals[i]) == (conditional_unbox(vals[i - 1]) - 1)); - } + auto guard = concurrency_guard::lock_const(); + return { this, inner::begin() }; } - } - // SetAt / GetMany - { - T vals[100]; - for (int i = 0; i < std::size(vals); ++i) + iterator end() { - v.Append(conditional_box(i)); + auto guard = concurrency_guard::lock_nonconst(); + return { this, inner::end() }; } - static constexpr int iterations = 1000; - unique_thread t([&] + const_iterator end() const { - for (int i = 1; i <= iterations; ++i) - { - for (int j = 0; j < std::size(vals); ++j) - { - v.SetAt(j, conditional_box(j + i)); - } - std::this_thread::yield(); - } - }); + auto guard = concurrency_guard::lock_const(); + return { this, inner::end() }; + } - while (conditional_unbox(v.GetAt(0)) != iterations) + bool empty() const { - v.GetMany(0, vals); - int jumps = 0; - for (int i = 1; i < std::size(vals); ++i) - { - auto prev = conditional_unbox(vals[i - 1]); - auto curr = conditional_unbox(vals[i]); - if (prev == curr) - { - ++jumps; - } - else - { - check_bool(curr == (prev + 1)); - } - } - check_bool(jumps <= 1); + auto guard = concurrency_guard::lock_const(); + return inner::empty(); } - } - // Append / ReplaceAll / GetMany - { - static constexpr int size = 10; - v.Clear(); - for (int i = 0; i < size; ++i) + reference back() { - v.Append(conditional_box(i)); + auto guard = concurrency_guard::lock_nonconst(); + return inner::back(); } - static constexpr int iterations = 1000; - unique_thread t([&] + const_reference back() const { - T newVals[size]; - for (int i = 1; i <= iterations; ++i) - { - for (int j = 0; j < size; ++j) - { - newVals[j] = conditional_box(i + j); - } - v.ReplaceAll(newVals); - } - }); + auto guard = concurrency_guard::lock_const(); + return inner::back(); + } - T vals[size]; - do + void pop_back() { - auto len = v.GetMany(0, vals); - check_bool(len == size); - for (int i = 1; i < size; ++i) - { - check_bool(conditional_unbox(vals[i]) == (conditional_unbox(vals[i - 1]) + 1)); - } + auto guard = concurrency_guard::lock_nonconst(); + inner::pop_back(); } - while (conditional_unbox(vals[0]) != iterations); - } -} - -template -static void test_iterator_invalidation(IVector const& v) -{ - static constexpr uint32_t final_size = 100000; - // Append / Size / First / HasCurrent / Current / MoveNext - unique_thread t([&] - { - for (int i = 0; i < final_size; ++i) + void push_back(T const& value) { - v.Append(conditional_box(i)); - std::this_thread::yield(); + auto guard = concurrency_guard::lock_nonconst(); + concurrency_guard::call_hook(collection_action::push_back); + inner::push_back(value); } - }); - int exceptionCount = 0; - bool forceExit = false; - while (!forceExit) - { - forceExit = v.Size() == final_size; - try + void push_back(T&& value) { - int expect = 0; - for (auto itr = v.First(); itr.HasCurrent(); itr.MoveNext()) - { - auto val = conditional_unbox(itr.Current()); - check_bool(val == expect++); - } - - if (expect == final_size) - { - break; - } + auto guard = concurrency_guard::lock_nonconst(); + concurrency_guard::call_hook(collection_action::push_back); + inner::push_back(std::move(value)); } - catch (hresult_changed_state const&) + + size_type size() const { - ++exceptionCount; + auto guard = concurrency_guard::lock_const(); + return inner::size(); } - check_bool(!forceExit); - } + iterator erase(const_iterator pos) + { + auto guard = concurrency_guard::lock_nonconst(); + concurrency_guard::call_hook(collection_action::erase); + return { this, inner::erase(pos) }; + } - // Since the insert thread yields after each insertion, this should really be in the thousands - REQUIRE(exceptionCount > 50); -} + iterator insert(const_iterator pos, T const& value) + { + auto guard = concurrency_guard::lock_nonconst(); + concurrency_guard::call_hook(collection_action::insert); + return { this, inner::insert(pos, value) }; + } -template -static void test_concurrent_iteration(IVector const& v) -{ - // Large enough size that all threads should have enough time to spin up - static constexpr uint32_t size = 100000; + iterator insert(const_iterator pos, T&& value) + { + auto guard = concurrency_guard::lock_nonconst(); + concurrency_guard::call_hook(collection_action::insert); + return { this, inner::insert(pos, std::move(value)) }; + } - // Append / Current / MoveNext / HasCurrent - { - for (int i = 0; i < size; ++i) + reference operator[](size_type pos) { - v.Append(conditional_box(i)); + return at(pos); } - auto itr = v.First(); - unique_thread threads[2]; - int increments[std::size(threads)] = {}; - for (int i = 0; i < std::size(threads); ++i) + reference at(size_type pos) { - threads[i] = unique_thread([&itr, &increments, i] - { - int last = -1; - while (true) - { - try - { - // NOTE: Because there is no atomic "get and increment" function on IIterator, the best we can do is - // validate that we're getting valid increasing values, e.g. as opposed to validating that we read - // all unique values. - auto val = conditional_unbox(itr.Current()); - check_bool(val > last); - check_bool(val < size); - last = val; - if (!itr.MoveNext()) - { - break; - } - - // MoveNext is the only synchronized operation that has a predictable side effect we can validate - ++increments[i]; - } - catch (hresult_error const&) - { - // There's no "get if" function, so concurrent increment past the end is always possible... - check_bool(!itr.HasCurrent()); - break; - } - } - }); + auto guard = concurrency_guard::lock_nonconst(); + concurrency_guard::call_hook(collection_action::at); + return inner::at(pos); } - for (auto& t : threads) + void clear() { - t.join(); + auto guard = concurrency_guard::lock_nonconst(); + return inner::clear(); } - REQUIRE(!itr.HasCurrent()); + template + void assign(InputIt first, InputIt last) + { + auto guard = concurrency_guard::lock_nonconst(); + return inner::assign(first, last); + } - auto totalIncrements = std::accumulate(std::begin(increments), std::end(increments), 0); - REQUIRE(totalIncrements == (size - 1)); - } + void reserve(size_type capacity) { + auto guard = concurrency_guard::lock_nonconst(); + return inner::reserve(capacity); + } - // HasCurrent / GetMany - { - auto itr = v.First(); - unique_thread threads[2]; - int totals[std::size(threads)] = {}; - for (int i = 0; i < std::size(threads); ++i) + void swap(concurrency_checked_vector& other) { - threads[i] = unique_thread([&itr, &totals, i]() - { - T vals[10]; - while (itr.HasCurrent()) - { - // Unlike 'Current', 'GetMany' _is_ atomic in regards to read+increment - auto len = itr.GetMany(vals); - totals[i] += std::accumulate(vals, vals + len, 0, [](int curr, T const& next) { return curr + conditional_unbox(next); }); - } - }); + auto guard = concurrency_guard::lock_nonconst(); + inner::swap(other); } - for (auto& t : threads) + template + decltype(auto) dereference_iterator(Iterator const& it) const { - t.join(); + auto guard = concurrency_guard::lock_const(); + concurrency_guard::call_hook(collection_action::at); + return *it; } - // sum(i = 1 -> N){i} = N * (N + 1) / 2 - auto total = std::accumulate(std::begin(totals), std::end(totals), 0); - REQUIRE(total == (size * (size - 1) / 2)); - } -} + operator array_view() + { + auto guard = concurrency_guard::lock_nonconst(); + return { inner::data(), static_cast(inner::size()) }; + } -template -static void test_multi_writer(IVector const& v) -{ - // Large enough that several threads should be executing concurrently - static constexpr uint32_t size = 10000; - static constexpr size_t threadCount = 8; + operator array_view() const + { + auto guard = concurrency_guard::lock_const(); + return { inner::data(), static_cast(inner::size()) }; + } + }; +#pragma endregion - unique_thread threads[threadCount]; - for (auto& t : threads) + template + void test_vector_concurrency() { - t = unique_thread([&v] + auto raw = concurrency_checked_vector(); + auto hook = raw.hook; + + // Convert the raw_vector into the desired Windows Runtime vector interface. + auto v = make_threaded_vector(std::move(raw)); + + auto race = [&](collection_action action, auto&& background, auto&& foreground) + { + // Vector initial contents are { 1, 2, 3 }. + v.ReplaceAll({ conditional_box(1), conditional_box(2), conditional_box(3) }); + hook->race(action, background, foreground); + }; + + // Verify that Append does not run concurrently with GetAt(). + race(collection_action::push_back, [&] + { + v.Append(conditional_box(42)); + }, [&] + { + CHECK(conditional_unbox(v.GetAt(3)) == 42); + }); + + // Verify that Append does not run concurrently with Size(). + race(collection_action::push_back, [&] + { + v.Append(conditional_box(42)); + }, [&] { - for (int i = 0; i < size; ++i) + CHECK(v.Size() == 4); + }); + + // Verify that Append does not run concurrently with IndexOf(). + race(collection_action::push_back, [&] + { + v.Append(conditional_box(42)); + }, [&] + { + uint32_t index; + bool found = v.IndexOf(conditional_box(3), index); + if constexpr (std::is_same_v) { - v.Append(conditional_box(i)); + CHECK(found); + CHECK(index == 2); + } + else + { + // Boxed integers do not compare equal even if the values are the same. + CHECK(!found); } }); - } - - for (auto& t : threads) - { - t.join(); - } - REQUIRE(v.Size() == (size * threadCount)); + // Verify that Append does not run concurrently with another Append(). + race(collection_action::push_back, [&] + { + v.Append(conditional_box(42)); + }, [&] + { + v.Append(conditional_box(43)); + }); - // sum(i = 1 -> N){i} = N * (N + 1) / 2 - auto sum = std::accumulate(begin(v), end(v), 0, [](int curr, T const& next) { return curr + conditional_unbox(next); }); - REQUIRE(sum == ((threadCount * (size - 1) * size) / 2)); -} + // Verify that Append does not run concurrently with ReplaceAll(). + race(collection_action::push_back, [&] + { + v.Append(conditional_box(42)); + }, [&] + { + v.ReplaceAll({ conditional_box(1), conditional_box(2) }); + }); -template -struct exclusive_vector : - vector_base, T>, - implements, IVector, IVectorView, IIterable> -{ - std::vector container; - mutable std::shared_mutex mutex; + // Verify that Append does not run concurrently with GetMany(). + race(collection_action::push_back, [&] + { + v.Append(conditional_box(42)); + }, [&] + { + T values[10]; + CHECK(v.GetMany(0, values) == 4); + CHECK(conditional_unbox(values[0]) == 1); + CHECK(conditional_unbox(values[1]) == 2); + CHECK(conditional_unbox(values[2]) == 3); + CHECK(conditional_unbox(values[3]) == 42); + }); - auto& get_container() noexcept - { - return container; - } + // Verify that InsertAt does not run concurrently with GetAt(). + race(collection_action::insert, [&] + { + v.InsertAt(1, conditional_box(42)); + }, [&] + { + CHECK(conditional_unbox(v.GetAt(1)) == 42); + }); - auto& get_container() const noexcept - { - return container; - } + // Verify that InsertAt does not run concurrently with Size(). + race(collection_action::insert, [&] + { + v.InsertAt(1, conditional_box(42)); + }, [&] + { + CHECK(v.Size() == 4); + }); - // It is not safe to recursively acquire an SRWLOCK, even in shared mode, however this is unchecked by the SRWLOCK - // implementation. Using a vector that only performs exclusive operations is the simplest way to validate that - // the implementation does not attempt to recursively acquire the mutex. - template - auto perform_exclusive(Func&& fn) const - { - // Exceptions are better than deadlocks... - REQUIRE(mutex.try_lock()); - std::lock_guard guard(mutex, std::adopt_lock); - return fn(); - } -}; + // Verify that InsertAt does not run concurrently with GetMany(). + race(collection_action::insert, [&] + { + v.InsertAt(1, conditional_box(42)); + }, [&] + { + T values[10]; + CHECK(v.GetMany(0, values) == 4); + CHECK(conditional_unbox(values[0]) == 1); + CHECK(conditional_unbox(values[1]) == 42); + CHECK(conditional_unbox(values[2]) == 2); + CHECK(conditional_unbox(values[3]) == 3); + }); -struct vector_deadlock_object : implements> -{ - int m_value; - exclusive_vector>* m_vector; + // Verify that RemoveAt does not run concurrently with GetAt(). + race(collection_action::erase, [&] + { + v.RemoveAt(1); + }, [&] + { + CHECK(conditional_unbox(v.GetAt(1)) == 3); + }); - vector_deadlock_object(int value, exclusive_vector>* vector) : - m_value(value), - m_vector(vector) - { - } + // Verify that RemoveAt does not run concurrently with Size(). + race(collection_action::erase, [&] + { + v.RemoveAt(1); + }, [&] + { + CHECK(v.Size() == 2); + }); - ~vector_deadlock_object() - { - // NOTE: This will crash on failure, but that's better than actually deadlocking - REQUIRE(m_vector->mutex.try_lock()); - m_vector->mutex.unlock(); - } + // Verify that SetAt does not run concurrently with GetAt(). + race(collection_action::at, [&] + { + v.SetAt(1, conditional_box(42)); + }, [&] + { + CHECK(conditional_unbox(v.GetAt(1)) == 42); + }); - int Value() const noexcept - { - return m_value; - } -}; + // Iterator invalidation tests are a little different because we perform + // the mutation from the foreground thread after the read operation + // has begun on the background thread. + { + // Verify that iterator invalidation doesn't race against + // iterator use. + decltype(v.First()) it; + T t; + race(collection_action::at, [&] + { + it = v.First(); + t = it.Current(); + }, [&] + { + v.InsertAt(0, conditional_box(42)); + }); + CHECK(conditional_unbox(t) == 1); + } -static void deadlock_test() -{ - auto v = make_self>>(); - - v->Append(make(42, v.get())); - v->InsertAt(0, make(8, v.get())); - REQUIRE(v->Size() == 2); - REQUIRE(v->GetAt(0).Value() == 8); - uint32_t index; - REQUIRE(v->IndexOf(42, index)); - REQUIRE(index == 1); - - v->ReplaceAll({ make(1, v.get()), make(2, v.get()), make(3, v.get()) }); - v->SetAt(1, make(4, v.get())); - { - IReference vals[5]; - REQUIRE(v->GetMany(0, vals) == 3); - REQUIRE(vals[0].Value() == 1); - REQUIRE(vals[1].Value() == 4); - REQUIRE(vals[2].Value() == 3); + { + // Verify that concurrent iteration works via GetMany(), which is atomic. + // (Current + MoveNext is non-atomic and can result in two threads + // both grabbing the same Current and then moving two steps forward.) + decltype(v.First()) it; + T t1[1]; + T t2[1]; + race(collection_action::at, [&] + { + it = v.First(); + CHECK(it.GetMany(t1) == 1); + }, [&] + { + CHECK(it.GetMany(t2) == 1); + }); + CHECK(conditional_unbox(t1[0]) != conditional_unbox(t2[0])); + } } - v->RemoveAt(1); - REQUIRE(v->GetAt(1).Value() == 3); - v->RemoveAtEnd(); - REQUIRE(v->GetAt(0).Value() == 1); - v->Clear(); - REQUIRE(v->Size() == 0); - - v->ReplaceAll({ make(1, v.get()), make(2, v.get()), make(3, v.get()) }); - auto view = v->GetView(); - REQUIRE(view.Size() == 3); - REQUIRE(view.GetAt(0).Value() == 1); - + void deadlock_test() { - IReference vals[5]; - REQUIRE(view.GetMany(0, vals) == 3); - REQUIRE(vals[0].Value() == 1); - REQUIRE(vals[1].Value() == 2); - REQUIRE(vals[2].Value() == 3); + auto v = make_threaded_vector(concurrency_checked_vector()); + v.Append(make>>(v)); + auto task = [](auto v)-> IAsyncAction + { + co_await resume_background(); + v.RemoveAtEnd(); + }(v); + auto status = task.wait_for(std::chrono::milliseconds(DEADLOCK_TIMEOUT)); + REQUIRE(status == AsyncStatus::Completed); } - - REQUIRE(view.IndexOf(2, index)); - REQUIRE(index == 1); - - auto itr = v->First(); - REQUIRE(itr.HasCurrent()); - REQUIRE(itr.Current().Value() == 1); - REQUIRE(itr.MoveNext()); - REQUIRE(itr.MoveNext()); - REQUIRE(!itr.MoveNext()); - REQUIRE(!itr.HasCurrent()); } TEST_CASE("multi_threaded_vector") { - test_single_reader_single_writer(multi_threaded_vector()); - test_single_reader_single_writer(multi_threaded_vector()); - - test_iterator_invalidation(multi_threaded_vector()); - test_iterator_invalidation(multi_threaded_vector()); - - test_concurrent_iteration(multi_threaded_vector()); - test_concurrent_iteration(multi_threaded_vector()); - - test_multi_writer(multi_threaded_vector()); - test_multi_writer(multi_threaded_vector()); + test_vector_concurrency(); + test_vector_concurrency(); deadlock_test(); } TEST_CASE("multi_threaded_observable_vector") { - test_single_reader_single_writer(multi_threaded_observable_vector()); - test_single_reader_single_writer(multi_threaded_observable_vector()); - test_single_reader_single_writer(multi_threaded_observable_vector().as>()); - - test_iterator_invalidation(multi_threaded_observable_vector()); - test_iterator_invalidation(multi_threaded_observable_vector()); - test_iterator_invalidation(multi_threaded_observable_vector().as>()); - - test_concurrent_iteration(multi_threaded_observable_vector()); - test_concurrent_iteration(multi_threaded_observable_vector()); - test_concurrent_iteration(multi_threaded_observable_vector().as>()); - - test_multi_writer(multi_threaded_observable_vector()); - test_multi_writer(multi_threaded_observable_vector()); - test_multi_writer(multi_threaded_observable_vector().as>()); + test_vector_concurrency(); + test_vector_concurrency(); + test_vector_concurrency(); }