From 1ad7559bf1f661458f72fed9430d7c54ad6498d1 Mon Sep 17 00:00:00 2001 From: Sergey Yablokov Date: Fri, 25 Apr 2025 09:11:54 +0200 Subject: [PATCH] SmallVector improvements --- internal/SmallVector.h | 120 +++++++++++++++------ internal/Vk/ContextVK.cpp | 2 +- tests/CMakeLists.txt | 1 + tests/main.cpp | 2 + tests/test_small_vector.cpp | 209 ++++++++++++++++++++++++++++++++++++ 5 files changed, 302 insertions(+), 32 deletions(-) create mode 100644 tests/test_small_vector.cpp diff --git a/internal/SmallVector.h b/internal/SmallVector.h index 3b57b5fd..352b43ec 100644 --- a/internal/SmallVector.h +++ b/internal/SmallVector.h @@ -21,7 +21,7 @@ template > cla static const uint32_t OwnerBit = (1u << (8u * sizeof(uint32_t) - 1u)); static const uint32_t CapacityMask = ~OwnerBit; - protected: + protected: SmallVectorImpl(T *begin, T *end, const uint32_t capacity, const Allocator &alloc) : Allocator(alloc), begin_(begin), size_(uint32_t(end - begin)), capacity_(capacity) {} @@ -47,7 +47,7 @@ template > cla reserve(new_capacity); } - public: + public: using iterator = T *; using const_iterator = const T *; @@ -118,10 +118,45 @@ template > cla operator Span() const { return Span(data(), size()); } + bool operator==(const SmallVectorImpl &rhs) const { + if (size_ != rhs.size_) { + return false; + } + bool eq = true; + for (uint32_t i = 0; i < size_ && eq; ++i) { + eq &= begin_[i] == rhs.begin_[i]; + } + return eq; + } + bool operator!=(const SmallVectorImpl &rhs) const { + if (size_ != rhs.size_) { + return true; + } + bool neq = false; + for (uint32_t i = 0; i < size_ && !neq; ++i) { + neq |= begin_[i] != rhs.begin_[i]; + } + return neq; + } + bool operator<(const SmallVectorImpl &rhs) const { + return std::lexicographical_compare(begin(), end(), rhs.begin(), rhs.end()); + } + bool operator<=(const SmallVectorImpl &rhs) const { + return !std::lexicographical_compare(rhs.begin(), rhs.end(), begin(), end()); + } + bool operator>(const SmallVectorImpl &rhs) const { + return std::lexicographical_compare(rhs.begin(), rhs.end(), begin(), end()); + } + bool operator>=(const SmallVectorImpl &rhs) const { + return !std::lexicographical_compare(begin(), end(), rhs.begin(), rhs.end()); + } + const T *cdata() const noexcept { return begin_; } const T *data() const noexcept { return begin_; } - const T *begin() const noexcept { return begin_; } - const T *end() const noexcept { return begin_ + size_; } + const_iterator begin() const noexcept { return begin_; } + const_iterator end() const noexcept { return begin_ + size_; } + const_iterator cbegin() const noexcept { return begin_; } + const_iterator cend() const noexcept { return begin_ + size_; } T *data() noexcept { return begin_; } iterator begin() noexcept { return begin_; } @@ -249,18 +284,45 @@ template > cla ensure_reserved(size_ + 1); pos = begin_ + off; - iterator move_dst = begin_ + size_, move_src = begin_ + size_ - 1; - while (move_dst != pos) { - (*move_dst) = std::move(*move_src); + iterator move_src = begin_ + size_ - 1, move_dst = move_src + 1; + while (move_src != pos - 1) { + new (move_dst) T(std::move(*move_src)); + move_src->~T(); --move_dst; --move_src; } + new (pos) T(value); ++size_; - new (move_dst) T(value); - return move_dst; + return pos; + } + + iterator insert(iterator pos, iterator beg, iterator end) { + assert(pos >= begin_ && pos <= begin_ + size_); + + const uint32_t count = uint32_t(end - beg); + const uint32_t off = uint32_t(pos - begin_); + ensure_reserved(size_ + count); + pos = begin_ + off; + + iterator move_src = begin_ + size_ - 1, move_dst = move_src + count; + while (move_src != pos - 1) { + new (move_dst) T(std::move(*move_src)); + move_src->~T(); + + --move_dst; + --move_src; + } + + move_dst = pos; + while (move_dst != pos + count) { + new (move_dst++) T(*beg++); + } + size_ += count; + + return pos; } iterator erase(iterator pos) { @@ -295,42 +357,33 @@ template > cla return move_dst; } + void assign(const uint32_t count, const T &val) { + clear(); + reserve(count); + for (uint32_t i = 0; i < count; ++i) { + push_back(val); + } + } + template void assign(const InputIt first, const InputIt last) { clear(); + reserve(uint32_t(last - first)); for (InputIt it = first; it != last; ++it) { push_back(*it); } } }; -template > -bool operator==(const SmallVectorImpl &lhs, const SmallVectorImpl &rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - for (const T *lhs_it = lhs.begin(), *rhs_it = rhs.begin(); lhs_it != lhs.end(); ++lhs_it, ++rhs_it) { - if (*lhs_it != *rhs_it) { - return false; - } - } - return true; -} - -template > -bool operator!=(const SmallVectorImpl &lhs, const SmallVectorImpl &rhs) { - return operator==(lhs, rhs); -} - template > class SmallVector : public SmallVectorImpl { alignas(AlignmentOfT) char buffer_[sizeof(T) * N]; - public: + public: SmallVector(const Allocator &alloc = Allocator()) // NOLINT : SmallVectorImpl((T *)buffer_, (T *)buffer_, N, alloc) {} - SmallVector(uint32_t initial_size, const T &val = T(), const Allocator &alloc = Allocator()) // NOLINT + explicit SmallVector(const uint32_t size, const T &val = T(), const Allocator &alloc = Allocator()) // NOLINT : SmallVectorImpl((T *)buffer_, (T *)buffer_, N, alloc) { - SmallVectorImpl::resize(initial_size, val); + SmallVectorImpl::resize(size, val); } SmallVector(const SmallVector &rhs) // NOLINT : SmallVectorImpl((T *)buffer_, (T *)buffer_, N, rhs.alloc()) { @@ -349,9 +402,14 @@ class SmallVector : public SmallVectorImpl { SmallVectorImpl::operator=(std::move(rhs)); } + template + SmallVector(InputIt beg, InputIt end, const Allocator &alloc = Allocator()) + : SmallVectorImpl((T *)buffer_, (T *)buffer_, N, alloc) { + SmallVectorImpl::assign(beg, end); + } + SmallVector(std::initializer_list l, const Allocator &alloc = Allocator()) : SmallVectorImpl((T *)buffer_, (T *)buffer_, N, alloc) { - SmallVectorImpl::reserve(uint32_t(l.size())); SmallVectorImpl::assign(l.begin(), l.end()); } diff --git a/internal/Vk/ContextVK.cpp b/internal/Vk/ContextVK.cpp index 7ead2ee3..c0703ce3 100644 --- a/internal/Vk/ContextVK.cpp +++ b/internal/Vk/ContextVK.cpp @@ -658,7 +658,7 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD api.vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(physical_device, &props_count, nullptr); SmallVector coop_matrix_props( - props_count, {VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR}); + props_count, VkCooperativeMatrixPropertiesKHR{VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR}); api.vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(physical_device, &props_count, coop_matrix_props.data()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 5f927748..746ce11c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -26,6 +26,7 @@ add_executable(test_Ray main.cpp test_simd_avx512.cpp test_simd_sse41.cpp test_simd.ipp + test_small_vector.cpp test_span.cpp test_sparse_storage.cpp test_tex_storage.cpp diff --git a/tests/main.cpp b/tests/main.cpp index df79bd63..da7e79f9 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -17,6 +17,7 @@ void test_huffman(); void test_inflate(); void test_scope_exit(); void test_freelist_alloc(); +void test_small_vector(); void test_span(); void test_sparse_storage(); void test_tex_storage(); @@ -204,6 +205,7 @@ int main(int argc, char *argv[]) { test_huffman(); test_inflate(); test_scope_exit(); + test_small_vector(); test_span(); test_sparse_storage(); test_tex_storage(); diff --git a/tests/test_small_vector.cpp b/tests/test_small_vector.cpp new file mode 100644 index 00000000..a8fa43e2 --- /dev/null +++ b/tests/test_small_vector.cpp @@ -0,0 +1,209 @@ +#include "test_common.h" + +#include + +#include "../internal/SmallVector.h" + +void test_small_vector() { + using namespace Ray; + + printf("Test small_vector | "); + + static_assert(sizeof(SmallVectorImpl) <= 16, "!"); + + { // basic usage with trivial type + SmallVector vec; + + for (int i = 0; i < 8; i++) { + vec.push_back(i); + } + for (int i = 8; i < 16; i++) { + vec.emplace_back(i); + } + + require(vec.empty() == false); + require(vec.size() == 16); + require(vec.capacity() == 16); + require(vec.is_on_heap() == false); + + for (int i = 0; i < 16; i++) { + require(vec[i] == i); + } + require(vec.back() == 15); + + vec.push_back(42); + + require(vec.empty() == false); + require(vec.size() == 17); + require(vec.is_on_heap() == true); + + for (int i = 0; i < 16; i++) { + require(vec[i] == i); + } + require(vec.back() == 42); + + vec.insert(vec.begin(), -42); + + require(vec.empty() == false); + require(vec.size() == 18); + require(vec.is_on_heap() == true); + + require(vec[0] == -42); + require(vec[1] == 0); + require(vec[2] == 1); + require(vec[3] == 2); + require(vec[4] == 3); + require(vec[5] == 4); + require(vec[6] == 5); + require(vec[7] == 6); + require(vec[8] == 7); + require(vec[9] == 8); + require(vec[10] == 9); + require(vec[11] == 10); + require(vec[12] == 11); + require(vec[13] == 12); + require(vec[14] == 13); + require(vec[15] == 14); + require(vec[16] == 15); + require(vec[17] == 42); + } + + { // basic usage with complicated type + SmallVector vec; + + for (int i = 0; i < 8; i++) { + vec.push_back(std::to_string(i)); + } + for (int i = 8; i < 16; i++) { + vec.emplace_back(std::to_string(i)); + } + + require(vec.empty() == false); + require(vec.size() == 16); + require(vec.capacity() == 16); + require(vec.is_on_heap() == false); + + for (int i = 0; i < 16; i++) { + require(vec[i] == std::to_string(i)); + } + require(vec.back() == "15"); + + vec.push_back("42"); + + require(vec.empty() == false); + require(vec.size() == 17); + require(vec.is_on_heap() == true); + + for (int i = 0; i < 16; i++) { + require(vec[i] == std::to_string(i)); + } + require(vec.back() == "42"); + + vec.insert(vec.begin(), "-42"); + + require(vec.empty() == false); + require(vec.size() == 18); + require(vec.is_on_heap() == true); + + require(vec[0] == "-42"); + require(vec[1] == "0"); + require(vec[2] == "1"); + require(vec[3] == "2"); + require(vec[4] == "3"); + require(vec[5] == "4"); + require(vec[6] == "5"); + require(vec[7] == "6"); + require(vec[8] == "7"); + require(vec[9] == "8"); + require(vec[10] == "9"); + require(vec[11] == "10"); + require(vec[12] == "11"); + require(vec[13] == "12"); + require(vec[14] == "13"); + require(vec[15] == "14"); + require(vec[16] == "15"); + require(vec[17] == "42"); + } + + { // usage with custom type + struct AAA { + char more_data[16] = {}; + int data; + + explicit AAA(int _data) : data(_data) {} + + AAA(const AAA &rhs) = delete; + AAA(AAA &&rhs) = default; + AAA &operator=(const AAA &rhs) = delete; + AAA &operator=(AAA &&rhs) = default; + }; + + SmallVector vec; + + for (int i = 0; i < 8; i++) { + vec.push_back(AAA{2 * i}); + vec.emplace_back(2 * i + 1); + } + require(vec.is_on_heap() == false); + require(vec.back().data == 15); + + require(vec.empty() == false); + require(vec.size() == 16); + require(vec.capacity() == 16); + + vec.push_back(AAA{42}); + + require(vec.is_on_heap() == true); + require(vec.back().data == 42); + + require(vec.empty() == false); + require(vec.size() == 17); + } + + { // erase + SmallVector vec; + for (int i = 0; i < 8; i++) { + vec.push_back(i); + } + for (int i = 8; i < 16; i++) { + vec.emplace_back(i); + } + + vec.erase(vec.begin() + 8, vec.begin() + 12); + + require(vec.empty() == false); + require(vec.size() == 12); + require(vec.capacity() == 16); + + require(vec[0] == 0); + require(vec[1] == 1); + require(vec[2] == 2); + require(vec[3] == 3); + require(vec[4] == 4); + require(vec[5] == 5); + require(vec[6] == 6); + require(vec[7] == 7); + require(vec[8] == 12); + require(vec[9] == 13); + require(vec[10] == 14); + require(vec[11] == 15); + } + + { // comparison operators + const SmallVector v1 = {1, 2, 3, 4, 5}; + const SmallVector v2 = {1, 2, 3, 4, 6}; + const SmallVector v3 = {1, 2, 3, 4, 5}; + const SmallVector v4 = {1, 2, 3, 4, 6}; + + require(v1 < v2); + require(v1 <= v3); + require(v2 > v1); + require(v2 >= v4); + require(v1 == v3); + require(v2 == v4); + require(v1 != v2); + require(v3 != v4); + } + + printf("OK\n"); +}