Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 53 additions & 60 deletions stl/inc/mdspan
Original file line number Diff line number Diff line change
Expand Up @@ -32,45 +32,38 @@ public:
using size_type = make_unsigned_t<index_type>;
using rank_type = size_t;

_NODISCARD static constexpr rank_type rank() noexcept {
return sizeof...(_Extents);
}

static_assert(_Is_standard_integer<index_type>,
"IndexType must be a signed or unsigned integer type (N4944 [mdspan.extents.overview]/1.1).");
static_assert(((_Extents == dynamic_extent || _STD in_range<index_type>(_Extents)) && ...),
"Each element of Extents must be either equal to dynamic_extent, or must be representable as a value of type "
"IndexType (N4944 [mdspan.extents.overview]/1.2).");

static constexpr rank_type _Rank = sizeof...(_Extents);
static constexpr rank_type _Rank_dynamic = (static_cast<rank_type>(_Extents == dynamic_extent) + ... + 0);

private:
static constexpr array<rank_type, _Rank> _Static_extents = {_Extents...};

_NODISCARD static _CONSTEVAL auto _Make_dynamic_indices() noexcept {
#pragma warning(push) // TRANSITION, "/analyze:only" BUG?
#pragma warning(disable : 28020) // The expression '0<=_Param_(1)&&_Param_(1)<=1-1' is not true at this call
array<rank_type, rank() + 1> _Result{};
array<rank_type, _Rank + 1> _Result{};
rank_type _Counter = 0;
for (rank_type _Idx = 0; _Idx < rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < _Rank; ++_Idx) {
_Result[_Idx] = _Counter;
if (_Static_extents[_Idx] == dynamic_extent) {
++_Counter;
}
}
_Result[rank()] = _Counter;
_Result[_Rank] = _Counter;
return _Result;
#pragma warning(pop) // TRANSITION, "/analyze:only" BUG?
}

static constexpr array<rank_type, rank()> _Static_extents = {_Extents...};
static constexpr array<rank_type, rank() + 1> _Dynamic_indices = _Make_dynamic_indices();

_NODISCARD static constexpr rank_type _Dynamic_index(rank_type _Idx) noexcept {
return _Dynamic_indices[_Idx];
}
static constexpr array<rank_type, _Rank + 1> _Dynamic_indices = _Make_dynamic_indices();

_NODISCARD static _CONSTEVAL auto _Make_dynamic_indices_inv() noexcept {
array<rank_type, rank()> _Result{};
for (rank_type _Idx = 0; _Idx < rank(); ++_Idx) {
for (rank_type _Idx_inv = 0; _Idx_inv < rank(); ++_Idx_inv) {
if (_Dynamic_index(_Idx_inv + 1) == _Idx + 1) {
array<rank_type, _Rank> _Result{};
for (rank_type _Idx = 0; _Idx < _Rank; ++_Idx) {
for (rank_type _Idx_inv = 0; _Idx_inv < _Rank; ++_Idx_inv) {
if (_Dynamic_indices[_Idx_inv + 1] == _Idx + 1) {
_Result[_Idx] = _Idx_inv;
break;
}
Expand All @@ -79,11 +72,7 @@ private:
return _Result;
}

static constexpr array<rank_type, rank()> _Dynamic_indices_inv = _Make_dynamic_indices_inv();

_NODISCARD static constexpr rank_type _Dynamic_index_inv(rank_type _Idx) noexcept {
return _Dynamic_indices_inv[_Idx];
}
static constexpr array<rank_type, _Rank> _Dynamic_indices_inv = _Make_dynamic_indices_inv();

struct _Static_extents_only {
constexpr explicit _Static_extents_only() noexcept = default;
Expand All @@ -96,30 +85,33 @@ private:
}
};

static constexpr rank_type _Rank_dynamic = _Dynamic_index(rank());
conditional_t<_Rank_dynamic != 0, array<index_type, _Rank_dynamic>, _Static_extents_only> _Dynamic_extents{};

public:
_NODISCARD static constexpr rank_type rank() noexcept {
return _Rank;
}

_NODISCARD static constexpr rank_type rank_dynamic() noexcept {
return _Rank_dynamic;
}

_NODISCARD static constexpr size_t static_extent(const rank_type _Idx) noexcept {
_STL_VERIFY(_Idx < rank(), "Index must be less than rank() (N4944 [mdspan.extents.obs]/1)");
_STL_VERIFY(_Idx < _Rank, "Index must be less than rank() (N4944 [mdspan.extents.obs]/1)");
return _Static_extents[_Idx];
}

_NODISCARD constexpr index_type extent(const rank_type _Idx) const noexcept {
_STL_VERIFY(_Idx < rank(), "Index must be less than rank() (N4944 [mdspan.extents.obs]/3)");
_STL_VERIFY(_Idx < _Rank, "Index must be less than rank() (N4944 [mdspan.extents.obs]/3)");
if constexpr (rank_dynamic() == 0) {
return static_cast<index_type>(static_extent(_Idx));
return static_cast<index_type>(_Static_extents[_Idx]);
} else if constexpr (rank_dynamic() == rank()) {
return _Dynamic_extents[_Idx];
} else {
if (static_extent(_Idx) == dynamic_extent) {
return _Dynamic_extents[_Dynamic_index(_Idx)];
if (_Static_extents[_Idx] == dynamic_extent) {
return _Dynamic_extents[_Dynamic_indices[_Idx]];
} else {
return static_cast<index_type>(static_extent(_Idx));
return static_cast<index_type>(_Static_extents[_Idx]);
}
}
}
Expand All @@ -133,16 +125,16 @@ public:
|| (numeric_limits<index_type>::max)() < (numeric_limits<_OtherIndexType>::max)())
extents(const extents<_OtherIndexType, _OtherExtents...>& _Other) noexcept {
auto _It = _Dynamic_extents.begin();
for (rank_type _Idx = 0; _Idx < rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < _Rank; ++_Idx) {
_STL_VERIFY(
static_extent(_Idx) == dynamic_extent || _STD cmp_equal(static_extent(_Idx), _Other.extent(_Idx)),
_Static_extents[_Idx] == dynamic_extent || _STD cmp_equal(_Static_extents[_Idx], _Other.extent(_Idx)),
"Value of other.extent(r) must be equal to extent(r) for each r for which extent(r) is a static extent "
"(N4944 [mdspan.extents.cons]/2.1)");
_STL_VERIFY(_STD in_range<index_type>(_Other.extent(_Idx)),
"Value of other.extent(r) must be representable as a value of type index_type for every rank index r "
"(N4944 [mdspan.extents.cons]/2.2)");

if (static_extent(_Idx) == dynamic_extent) {
if (_Static_extents[_Idx] == dynamic_extent) {
*_It = static_cast<index_type>(_Other.extent(_Idx));
++_It;
}
Expand All @@ -167,12 +159,12 @@ public:
} else {
array<index_type, sizeof...(_Exts)> _Exts_arr{static_cast<index_type>(_STD move(_Exts))...};
auto _It = _Dynamic_extents.begin();
for (rank_type _Idx = 0; _Idx < rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < _Rank; ++_Idx) {
_STL_VERIFY(
static_extent(_Idx) == dynamic_extent || _STD cmp_equal(static_extent(_Idx), _Exts_arr[_Idx]),
_Static_extents[_Idx] == dynamic_extent || _STD cmp_equal(_Static_extents[_Idx], _Exts_arr[_Idx]),
"Value of exts_arr[r] must be equal to extent(r) for each r for which extent(r) is a static extent "
"(N4944 [mdspan.extents.cons]/7.1)");
if (static_extent(_Idx) == dynamic_extent) {
if (_Static_extents[_Idx] == dynamic_extent) {
*_It = _Exts_arr[_Idx];
++_It;
}
Expand All @@ -199,10 +191,11 @@ public:
requires is_convertible_v<const _OtherIndexType&, index_type>
&& is_nothrow_constructible_v<index_type, const _OtherIndexType&> && (_Size == rank())
constexpr explicit extents(span<_OtherIndexType, _Size> _Exts, index_sequence<_Indices...>) noexcept
: _Dynamic_extents{static_cast<index_type>(_STD as_const(_Exts[_Dynamic_index_inv(_Indices)]))...} {
: _Dynamic_extents{static_cast<index_type>(_STD as_const(_Exts[_Dynamic_indices_inv[_Indices]]))...} {
if constexpr (_Is_standard_integer<_OtherIndexType>) {
for (rank_type _Idx = 0; _Idx < rank(); ++_Idx) {
_STL_VERIFY(static_extent(_Idx) == dynamic_extent || _STD cmp_equal(static_extent(_Idx), _Exts[_Idx]),
for (rank_type _Idx = 0; _Idx < _Rank; ++_Idx) {
_STL_VERIFY(
_Static_extents[_Idx] == dynamic_extent || _STD cmp_equal(_Static_extents[_Idx], _Exts[_Idx]),
"Value of exts[r] must be equal to extent(r) for each r for which extent(r) is a static extent "
"(N4944 [mdspan.extents.cons]/10.1)");
_STL_VERIFY(_Exts[_Idx] >= 0 && _STD in_range<index_type>(_Exts[_Idx]),
Expand Down Expand Up @@ -232,7 +225,7 @@ public:
if constexpr (rank() != sizeof...(_OtherExtents)) {
return false;
} else {
for (rank_type _Idx = 0; _Idx < rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < _Rank; ++_Idx) {
if (_STD cmp_not_equal(_Left.extent(_Idx), _Right.extent(_Idx))) {
return false;
}
Expand All @@ -253,7 +246,7 @@ public:
// TRANSITION, LWG ISSUE? I believe that this function should return 'index_type'
_NODISCARD constexpr index_type _Rev_prod_of_extents(const rank_type _Idx) const noexcept {
index_type _Result = 1;
for (rank_type _Dim = _Idx + 1; _Dim < rank(); ++_Dim) {
for (rank_type _Dim = _Idx + 1; _Dim < _Rank; ++_Dim) {
_Result *= extent(_Dim);
}
return _Result;
Expand Down Expand Up @@ -361,7 +354,7 @@ public:
const bool _Verify = [&]<size_t... _Indices>(index_sequence<_Indices...>) {
index_type _Prod = 1;
return ((_Other.stride(_Indices)
== (_Indices == extents_type::rank() - 1
== (_Indices == extents_type::_Rank - 1
? _Prod
: _STD exchange(_Prod, static_cast<index_type>(_Prod * _Exts.extent(_Indices)))))
&& ...);
Expand All @@ -382,7 +375,7 @@ public:
}

_NODISCARD constexpr index_type required_span_size() const noexcept {
return _Exts._Fwd_prod_of_extents(extents_type::rank());
return _Exts._Fwd_prod_of_extents(extents_type::_Rank);
}

template <class... _IndexTypes>
Expand Down Expand Up @@ -419,7 +412,7 @@ public:
_NODISCARD constexpr index_type stride(const rank_type _Idx) const noexcept
requires (extents_type::rank() > 0)
{
_STL_VERIFY(_Idx < extents_type::rank(),
_STL_VERIFY(_Idx < extents_type::_Rank,
"Value of i must be less than extents_type::rank() (N4944 [mdspan.layout.left.obs]/6).");
return _Exts._Fwd_prod_of_extents(_Idx);
}
Expand Down Expand Up @@ -494,7 +487,7 @@ public:
index_type _Prod = stride(0);
return (
(_Other.stride(_Indices)
== (_Indices == extents_type::rank() - 1
== (_Indices == extents_type::_Rank - 1
? _Prod
: _STD exchange(_Prod, static_cast<index_type>(_Prod / _Exts.extent(_Indices + 1)))))
&& ...);
Expand All @@ -515,7 +508,7 @@ public:
}

_NODISCARD constexpr index_type required_span_size() const noexcept {
return _Exts._Fwd_prod_of_extents(extents_type::rank());
return _Exts._Fwd_prod_of_extents(extents_type::_Rank);
}

template <class... _IndexTypes>
Expand Down Expand Up @@ -552,7 +545,7 @@ public:
_NODISCARD constexpr index_type stride(const rank_type _Idx) const noexcept
requires (extents_type::rank() > 0)
{
_STL_VERIFY(_Idx < extents_type::rank(),
_STL_VERIFY(_Idx < extents_type::_Rank,
"Value of i must be less than extents_type::rank() (N4944 [mdspan.layout.right.obs]/6).");
return _Exts._Rev_prod_of_extents(_Idx);
}
Expand Down Expand Up @@ -606,7 +599,7 @@ public:
constexpr mapping() noexcept : _Exts(extents_type{}) {
if constexpr (extents_type::rank() != 0) {
_Strides.back() = 1;
for (rank_type _Idx = extents_type::rank() - 1; _Idx-- > 0;) {
for (rank_type _Idx = extents_type::_Rank - 1; _Idx-- > 0;) {
// TRANSITION USE `_Multiply_with_overflow_check` IN DEBUG MODE
_Strides[_Idx] = _Strides[_Idx + 1] * _Exts.extent(_Idx + 1);
}
Expand All @@ -622,7 +615,7 @@ public:
constexpr mapping(const extents_type& _Exts_, span<_OtherIndexType, extents_type::rank()> _Strides_,
index_sequence<_Indices...>) noexcept
: _Exts(_Exts_), _Strides{static_cast<index_type>(_STD as_const(_Strides_[_Indices]))...} {
for (rank_type _Idx = 0; _Idx < extents_type::rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < extents_type::_Rank; ++_Idx) {
// TRANSITION CHECK [mdspan.layout.stride.cons]/4.2 (REQUIRES `_Multiply_with_overflow_check`)
_STL_VERIFY(_Strides[_Idx] > 0, "Value of s[i] must be greater than 0 for all i in the range [0, rank_) "
"(N4944 [mdspan.layout.stride.cons]/4.1).");
Expand Down Expand Up @@ -671,7 +664,7 @@ public:
"[mdspan.layout.stride.cons]/7.3).");
_STL_VERIFY(
_Offset(_Other) == 0, "Value of OFFSET(other) must be equal to 0 (N4944 [mdspan.layout.stride.cons]/7.4).");
for (rank_type _Idx = 0; _Idx < extents_type::rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < extents_type::_Rank; ++_Idx) {
const auto _Stride = _Other.stride(_Idx);
_STL_VERIFY(_Stride > 0, "Value of other.stride(r) must be greater than 0 for every rank index r of "
"extents() (N4944 [mdspan.layout.stride.cons]/7.2).");
Expand All @@ -696,7 +689,7 @@ public:
return 1;
} else {
index_type _Result = 1;
for (rank_type _Idx = 0; _Idx < extents_type::rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < extents_type::_Rank; ++_Idx) {
const index_type _Ext = _Exts.extent(_Idx);
if (_Ext == 0) {
return 0;
Expand Down Expand Up @@ -737,7 +730,7 @@ public:
if constexpr (extents_type::rank() == 0) {
return true;
} else {
return required_span_size() == _Exts._Fwd_prod_of_extents(extents_type::rank());
return required_span_size() == _Exts._Fwd_prod_of_extents(extents_type::_Rank);
}
}

Expand All @@ -758,7 +751,7 @@ public:
return false;
}

for (rank_type _Idx = 0; _Idx < extents_type::rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < extents_type::_Rank; ++_Idx) {
if (_Left.stride(_Idx) != _Right.stride(_Idx)) {
return false;
}
Expand All @@ -777,7 +770,7 @@ private:
if constexpr (extents_type::rank() == 0) {
return _Mapping();
} else {
for (rank_type _Idx = 0; _Idx < extents_type::rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < extents_type::_Rank; ++_Idx) {
if (_Mapping.extents().extent(_Idx) == 0) {
return 0;
}
Expand Down Expand Up @@ -855,11 +848,11 @@ public:
"[mdspan.mdspan.overview]/2.3).");

_NODISCARD static constexpr rank_type rank() noexcept {
return extents_type::rank();
return extents_type::_Rank;
}

_NODISCARD static constexpr rank_type rank_dynamic() noexcept {
return extents_type::rank_dynamic();
return extents_type::_Rank_dynamic;
}

_NODISCARD static constexpr size_t static_extent(const rank_type _Idx) noexcept {
Expand Down Expand Up @@ -956,11 +949,11 @@ public:
}

_NODISCARD constexpr size_type size() const noexcept {
return static_cast<size_type>(_Map.extents()._Fwd_prod_of_extents(rank()));
return static_cast<size_type>(_Map.extents()._Fwd_prod_of_extents(extents_type::_Rank));
}

_NODISCARD constexpr bool empty() const noexcept {
for (rank_type _Idx = 0; _Idx < rank(); ++_Idx) {
for (rank_type _Idx = 0; _Idx < extents_type::_Rank; ++_Idx) {
if (_Map.extents().extent(_Idx) == 0) {
return true;
}
Expand Down
6 changes: 3 additions & 3 deletions tests/std/tests/P0009R18_mdspan_layout_stride/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using namespace std;

template <class IndexType, size_t... Extents, size_t... Indices>
constexpr void do_check_members(const extents<IndexType, Extents...>& ext,
const array<IndexType, sizeof...(Extents)> strs, index_sequence<Indices...>) {
const array<IndexType, sizeof...(Extents)>& strs, index_sequence<Indices...>) {
using Ext = extents<IndexType, Extents...>;
using Strides = array<IndexType, sizeof...(Extents)>;
using Mapping = layout_stride::mapping<Ext>;
Expand Down Expand Up @@ -116,7 +116,7 @@ constexpr void do_check_members(const extents<IndexType, Extents...>& ext,
// Tests of 'is_exhaustive' are defined in 'check_is_exhaustive' function [FIXME]
}

{ // Check 'stride' function (intentionally not if constexpr)
{ // Check 'stride' function
for (size_t i = 0; i < strs.size(); ++i) {
same_as<IndexType> decltype(auto) s = m.stride(i);
assert(strs[i] == s);
Expand All @@ -132,7 +132,7 @@ constexpr void do_check_members(const extents<IndexType, Extents...>& ext,
}

template <class IndexType, size_t... Extents>
constexpr void check_members(extents<IndexType, Extents...> ext, const array<IndexType, sizeof...(Extents)> strides) {
constexpr void check_members(extents<IndexType, Extents...> ext, const array<IndexType, sizeof...(Extents)>& strides) {
do_check_members<IndexType, Extents...>(ext, strides, make_index_sequence<sizeof...(Extents)>{});
}

Expand Down