Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ static Vector512<T> CreateVector512T(int i) =>
nanMask = ~Vector512.Equals(result, result);
if (nanMask != Vector512<T>.Zero)
{
return IndexOfFirstMatch(nanMask);
return Vector512.IndexOfWhereAllBitsSet(nanMask);
}
}

Expand All @@ -195,7 +195,7 @@ static Vector512<T> CreateVector512T(int i) =>
nanMask = ~Vector512.Equals(current, current);
if (nanMask != Vector512<T>.Zero)
{
return i + IndexOfFirstMatch(nanMask);
return i + Vector512.IndexOfWhereAllBitsSet(nanMask);
}
}

Expand All @@ -215,7 +215,7 @@ static Vector512<T> CreateVector512T(int i) =>
nanMask = ~Vector512.Equals(current, current);
if (nanMask != Vector512<T>.Zero)
{
int indexInVectorOfFirstMatch = IndexOfFirstMatch(nanMask);
int indexInVectorOfFirstMatch = Vector512.IndexOfWhereAllBitsSet(nanMask);
return typeof(T) == typeof(double) ?
(int)(long)(object)currentIndex.As<T, long>()[indexInVectorOfFirstMatch] :
(int)(object)currentIndex.As<T, int>()[indexInVectorOfFirstMatch];
Expand Down Expand Up @@ -260,7 +260,7 @@ static Vector256<T> CreateVector256T(int i) =>
nanMask = ~Vector256.Equals(result, result);
if (nanMask != Vector256<T>.Zero)
{
return IndexOfFirstMatch(nanMask);
return Vector256.IndexOfWhereAllBitsSet(nanMask);
}
}

Expand All @@ -279,7 +279,7 @@ static Vector256<T> CreateVector256T(int i) =>
nanMask = ~Vector256.Equals(current, current);
if (nanMask != Vector256<T>.Zero)
{
return i + IndexOfFirstMatch(nanMask);
return i + Vector256.IndexOfWhereAllBitsSet(nanMask);
}
}

Expand All @@ -299,7 +299,7 @@ static Vector256<T> CreateVector256T(int i) =>
nanMask = ~Vector256.Equals(current, current);
if (nanMask != Vector256<T>.Zero)
{
int indexInVectorOfFirstMatch = IndexOfFirstMatch(nanMask);
int indexInVectorOfFirstMatch = Vector256.IndexOfWhereAllBitsSet(nanMask);
return typeof(T) == typeof(double) ?
(int)(long)(object)currentIndex.As<T, long>()[indexInVectorOfFirstMatch] :
(int)(object)currentIndex.As<T, int>()[indexInVectorOfFirstMatch];
Expand Down Expand Up @@ -344,7 +344,7 @@ static Vector128<T> CreateVector128T(int i) =>
nanMask = ~Vector128.Equals(result, result);
if (nanMask != Vector128<T>.Zero)
{
return IndexOfFirstMatch(nanMask);
return Vector128.IndexOfWhereAllBitsSet(nanMask);
}
}

Expand All @@ -363,7 +363,7 @@ static Vector128<T> CreateVector128T(int i) =>
nanMask = ~Vector128.Equals(current, current);
if (nanMask != Vector128<T>.Zero)
{
return i + IndexOfFirstMatch(nanMask);
return i + Vector128.IndexOfWhereAllBitsSet(nanMask);
}
}

Expand All @@ -383,7 +383,7 @@ static Vector128<T> CreateVector128T(int i) =>
nanMask = ~Vector128.Equals(current, current);
if (nanMask != Vector128<T>.Zero)
{
int indexInVectorOfFirstMatch = IndexOfFirstMatch(nanMask);
int indexInVectorOfFirstMatch = Vector128.IndexOfWhereAllBitsSet(nanMask);
return typeof(T) == typeof(double) ?
(int)(long)(object)currentIndex.As<T, long>()[indexInVectorOfFirstMatch] :
(int)(object)currentIndex.As<T, int>()[indexInVectorOfFirstMatch];
Expand Down Expand Up @@ -419,15 +419,6 @@ static Vector128<T> CreateVector128T(int i) =>
return curIn;
}

/// <summary>Gets the index of the first matching lane in a comparison mask.</summary>
/// <param name="mask">Per-element mask where a matching lane has its most significant bit set.</param>
/// <returns>The zero-based lane index of the first match. Callers ensure <paramref name="mask"/> is non-zero.</returns>
private static int IndexOfFirstMatch<T>(Vector128<T> mask)
{
    // One bit per lane: bit i mirrors the most significant bit of element i.
    uint laneBits = mask.ExtractMostSignificantBits();

    // The lowest set bit corresponds to the first matching lane.
    return BitOperations.TrailingZeroCount(laneBits);
}

/// <summary>Gets the index of the first matching lane in a comparison mask.</summary>
/// <param name="mask">Per-element mask where a matching lane has its most significant bit set.</param>
/// <returns>The zero-based lane index of the first match. Callers ensure <paramref name="mask"/> is non-zero.</returns>
private static int IndexOfFirstMatch<T>(Vector256<T> mask)
{
    // One bit per lane: bit i mirrors the most significant bit of element i.
    uint laneBits = mask.ExtractMostSignificantBits();

    // The lowest set bit corresponds to the first matching lane.
    return BitOperations.TrailingZeroCount(laneBits);
}

/// <summary>Gets the index of the first matching lane in a comparison mask.</summary>
/// <param name="mask">Per-element mask where a matching lane has its most significant bit set.</param>
/// <returns>The zero-based lane index of the first match. Callers ensure <paramref name="mask"/> is non-zero.</returns>
private static int IndexOfFirstMatch<T>(Vector512<T> mask)
{
    // One bit per lane: bit i mirrors the most significant bit of element i.
    ulong laneBits = mask.ExtractMostSignificantBits();

    // The lowest set bit corresponds to the first matching lane.
    return BitOperations.TrailingZeroCount(laneBits);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector256<T> IndexLessThan<T>(Vector256<T> indices1, Vector256<T> indices2) =>
sizeof(T) == sizeof(long) ? Vector256.LessThan(indices1.AsInt64(), indices2.AsInt64()).As<long, T>() :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = IsNaN(result);
if (nanMask != Vector512<T>.Zero)
{
return result.GetElement(IndexOfFirstMatch(nanMask));
return result.GetElement(Vector512.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -277,7 +277,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector512.Equals(current, current);
if (nanMask != Vector512<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector512.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -296,7 +296,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector512.Equals(current, current);
if (nanMask != Vector512<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector512.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -323,7 +323,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector256.Equals(result, result);
if (nanMask != Vector256<T>.Zero)
{
return result.GetElement(IndexOfFirstMatch(nanMask));
return result.GetElement(Vector256.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -342,7 +342,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector256.Equals(current, current);
if (nanMask != Vector256<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector256.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -362,7 +362,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector256.Equals(current, current);
if (nanMask != Vector256<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector256.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -389,7 +389,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = IsNaN(result);
if (nanMask != Vector128<T>.Zero)
{
return result.GetElement(IndexOfFirstMatch(nanMask));
return result.GetElement(Vector128.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -408,7 +408,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = IsNaN(current);
if (nanMask != Vector128<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector128.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -427,7 +427,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = IsNaN(current);
if (nanMask != Vector128<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector128.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand Down
55 changes: 25 additions & 30 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -531,16 +531,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector128<byte> search = Vector128.Load(searchSpace + offset);

// Same method as below
uint matches = Vector128.Equals(Vector128<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector128<byte> cmp = Vector128.Equals(Vector128<byte>.Zero, search);
if (cmp == Vector128<byte>.Zero)
Comment on lines -534 to +535
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@EgorBo, so on x64 this is basically going to do:

                                        ; Approx 8 total cycles
    vxorps    xmm0, xmm0, xmm0          ; 0 cycles
    vpcmpeqb  xmm0, xmm0, xmm1          ; 1 cycle
    vptest    xmm0, xmm0                ; 7 cycles
    jz        SHORT NO_MATCH            ; fused

MATCH:                                  ; Approx 10 total cycles
    vpmovmskb eax, xmm0                 ; 5 cycles
    tzcnt     eax, eax                  ; 1 cycle
    mov       ecx, -1                   ; 1 cycle
    cmp       eax, 32                   ; 1 cycle
    cmove     eax, ecx                  ; 1 cycle
    add       eax, edx                  ; 1 cycle
    ret                                 ; return

NO_MATCH:
    ; ...

and on Arm64 (neoverse v2):

                                        ; Approx 7 total cycles
    cmeq    v16.16b, v0.16b, #0         ; 2 cycles
    umaxp   v17.4s, v16.4s, v16.4s      ; 2 cycles
    umov    x1, v17.d[0]                ; 2 cycles
    cmp     x1, #0                      ; 1 cycle
    b.eq    NO_MATCH                    ; branch

MATCH:                                  ; Approx 10 total cycles
    shrn    v16.8b, v16.8h, #4          ; 2 cycles
    umov    x1, v16.d[0]                ; 2 cycles
    rbit    x1, x1                      ; 1 cycle
    clz     x1, x1                      ; 1 cycle
    lsr     w1, w1, #2                  ; 1 cycle
    movn    w2, #0                      ; 1 cycle
    cmp     w1, #16                     ; 1 cycle
    csel    w1, w1, w2, ne              ; fused
    add     w0, w0, w1                  ; 1 cycle
    ret     lr                          ; return

NO_MATCH:
    ; ...

More ideally the JIT could recognize this general pattern and generate this instead for x64:

                                        ; Approx 7 total cycles
    vxorps    xmm0, xmm0, xmm0          ; 0 cycles
    vpcmpeqb  xmm0, xmm0, xmm1          ; 1 cycle
    vpmovmskb eax, xmm0                 ; 5 cycles
    cmp       eax, 0                    ; 1 cycle
    jz        SHORT NO_MATCH            ; fused

MATCH:                                  ; Approx 2 total cycle
    tzcnt     eax, eax                  ; 1 cycle
    add       eax, edx                  ; 1 cycle
    ret                                 ; return

NO_MATCH:
    ; ...

and this on Arm64:

                                        ; Approx 7 total cycles
    cmeq    v16.16b, v0.16b, #0         ; 2 cycles
    shrn    v16.8b, v16.8h, #4          ; 2 cycles
    umov    x1, v16.d[0]                ; 2 cycles
    cmp     w1, #0                      ; 1 cycle
    b.eq    NO_MATCH

MATCH:                                  ; Approx 4 total cycle
    rbit    x1, x1                      ; 1 cycle
    clz     x1, x1                      ; 1 cycle
    lsr     w1, w1, #2                  ; 1 cycle
    add     w0, w0, w1                  ; 1 cycle
    ret     lr                          ; return

NO_MATCH:
    ; ...

This would make it significantly cheaper for both, but I think requires us to recognize the != Zero followed by an Count/IndexOf/LastIndexOf pattern. Specifically I think CSE would trivially handle this for Arm64, but on x64 we'd need to transform the != Zero in that case so CSE could kick in.

What are your thoughts on this?


The alternative is we setup the managed code to look like this:

int index = Vector128.IndexOf(search, 0);

if (index < 0)
{
    // Zero flags set so no matches
    offset += (nuint)Vector128<byte>.Count;
}
else
{
    // Find bitflag offset of first match and add to current offset
    return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}

Then we'd get this (roughly) on x64:

                                        ; Approx 11 total cycles
    vxorps    xmm0, xmm0, xmm0          ; 0 cycles
    vpcmpeqb  xmm0, xmm0, xmm1          ; 1 cycle
    vpmovmskb eax, xmm0                 ; 5 cycles
    tzcnt     eax, eax                  ; 1 cycle
    mov       ecx, -1                   ; 1 cycle
    cmp       eax, 32                   ; 1 cycle
    cmove     eax, ecx                  ; 1 cycle
    cmp       eax, 0                    ; 1 cycle
    jl        SHORT NO_MATCH            ; fused

MATCH:                                  ; Approx 1 total cycle
    add       eax, edx                  ; 1 cycle
    ret                                 ; return

NO_MATCH:
    ; ...

and this on Arm64:

                                        ; Approx 10 total cycles
    cmeq    v16.16b, v0.16b, #0         ; 2 cycles
    shrn    v16.8b, v16.8h, #4          ; 2 cycles
    umov    x1, v16.d[0]                ; 2 cycles
    rbit    x1, x1                      ; 1 cycle
    clz     x1, x1                      ; 1 cycle
    lsr     w1, w1, #2                  ; 1 cycle
    cmp     w1, #0                      ; 1 cycle
    b.ge    NO_MATCH

MATCH:                                  ; Approx 1 total cycle
    add     w0, w0, w1                  ; 1 cycle
    ret     lr                          ; return

NO_MATCH:
    ; ...

This is a little less than half the cost on match on both platforms, but has slightly higher cost for the no match scenario.

But I expect this is also difficult to pattern match and handle to get it to generate what we want in the first scenario, right?

We should probably pick one and have that be the "recommended pattern" where we then have the JIT handle it for the ideal codegen. -- The "other" other thing we could do is use Vector128.AnyWhereAllBitsSet(mask) instead of mask != Vector128<T>.Zero, which might then be easier to optimize overall, but interested in your thoughts so we can work towards getting it optimized and have managed follow our desired shape.

{
// Zero flags set so no matches
offset += (nuint)Vector128<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}
}

Expand All @@ -553,16 +553,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector256<byte> search = Vector256.Load(searchSpace + offset);

// Same method as below
uint matches = Vector256.Equals(Vector256<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector256<byte> cmp = Vector256.Equals(Vector256<byte>.Zero, search);
if (cmp == Vector256<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector256<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector256.IndexOfFirstMatch(cmp));
}
}
lengthToExamine = GetByteVector512SpanLength(offset, Length);
Expand All @@ -571,18 +571,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
do
{
Vector512<byte> search = Vector512.Load(searchSpace + offset);
ulong matches = Vector512.Equals(Vector512<byte>.Zero, search).ExtractMostSignificantBits();
// Note that MoveMask has converted the equal vector elements into a set of bit flags,
// So the bit position in 'matches' corresponds to the element offset.
if (matches == 0)
Vector512<byte> cmp = Vector512.Equals(Vector512<byte>.Zero, search);
if (cmp == Vector512<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector512<byte>.Count;
continue;
}

// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector512.IndexOfFirstMatch(cmp));
} while (lengthToExamine > offset);
}

Expand All @@ -592,16 +590,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector256<byte> search = Vector256.Load(searchSpace + offset);

// Same method as above
uint matches = Vector256.Equals(Vector256<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector256<byte> cmp = Vector256.Equals(Vector256<byte>.Zero, search);
if (cmp == Vector256<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector256<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector256.IndexOfFirstMatch(cmp));
}
}

Expand All @@ -611,16 +609,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector128<byte> search = Vector128.Load(searchSpace + offset);

// Same method as above
uint matches = Vector128.Equals(Vector128<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector128<byte> cmp = Vector128.Equals(Vector128<byte>.Zero, search);
if (cmp == Vector128<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector128<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}
}

Expand All @@ -644,16 +642,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector128<byte> search = Vector128.Load(searchSpace + offset);

// Same method as below
uint matches = Vector128.Equals(Vector128<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector128<byte> cmp = Vector128.Equals(Vector128<byte>.Zero, search);
if (cmp == Vector128<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector128<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}
}

Expand All @@ -663,18 +661,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
do
{
Vector256<byte> search = Vector256.Load(searchSpace + offset);
uint matches = Vector256.Equals(Vector256<byte>.Zero, search).ExtractMostSignificantBits();
// Note that MoveMask has converted the equal vector elements into a set of bit flags,
// So the bit position in 'matches' corresponds to the element offset.
if (matches == 0)
Vector256<byte> cmp = Vector256.Equals(Vector256<byte>.Zero, search);
if (cmp == Vector256<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector256<byte>.Count;
continue;
}

// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector256.IndexOfFirstMatch(cmp));
} while (lengthToExamine > offset);
}

Expand All @@ -684,16 +680,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector128<byte> search = Vector128.Load(searchSpace + offset);

// Same method as above
uint matches = Vector128.Equals(Vector128<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector128<byte> cmp = Vector128.Equals(Vector128<byte>.Zero, search);
if (cmp == Vector128<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector128<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}
}

Expand Down Expand Up @@ -724,8 +720,7 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
}

// Find bitflag offset of first match and add to current offset
uint matches = compareResult.ExtractMostSignificantBits();
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(compareResult));
}

if (offset < (nuint)(uint)Length)
Expand Down
Loading
Loading