Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 27 additions & 45 deletions src/libraries/System.Linq/src/System/Linq/Sum.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Linq
{
Expand Down Expand Up @@ -84,9 +83,6 @@ private static T SumSignedIntegersVectorized<T>(ReadOnlySpan<T> span)
Debug.Assert(Vector<T>.Count > 2);
Debug.Assert(Vector.IsHardwareAccelerated);

ref T ptr = ref MemoryMarshal.GetReference(span);
nuint length = (nuint)span.Length;

// Overflow testing for vectors is based on setting the sign bit of the overflowTracking
// vector for an element if the following are all true:
// - The two elements being summed have the same sign bit. If one element is positive
Expand All @@ -104,71 +100,59 @@ private static T SumSignedIntegersVectorized<T>(ReadOnlySpan<T> span)
// Thus, if we had a sign swap compared to both inputs, then signof(input1) == signof(input2) and
// we must have overflowed.
//
// By bitwise or-ing the overflowTracking vector for each step we can save cycles by testing
// the sign bits less often. If any iteration has the sign bit set in any element it indicates
// there was an overflow.
// By bitwise or-ing the overflowTracking vector throughout the entire loop and
// only testing it once at the end we save the cost of an in-loop test+branch per
// iteration. If any accumulation across the whole input has the sign bit set in
// any element it indicates there was an overflow.
Comment thread
EgorBo marked this conversation as resolved.
//
// Note: The overflow checking in this algorithm is only correct for signed integers.
// If support is ever added for unsigned integers then the overflow check should be:
// overflowTracking |= (input1 & input2) | Vector.AndNot(input1 | input2, result);

Vector<T> accumulator = Vector<T>.Zero;
Vector<T> overflowTracking = Vector<T>.Zero;

// Build a test vector with only the sign bit set in each element.
Vector<T> overflowTestVector = new(T.MinValue);

// Unroll the loop to sum 4 vectors per iteration. This reduces range check
// and overflow check frequency, allows us to eliminate move operations swapping
// accumulators, and may have pipelining benefits.
nuint index = 0;
nuint limit = length - (nuint)Vector<T>.Count * 4;
do
// Unroll the loop to sum 4 vectors per iteration. This allows us to eliminate
// move operations swapping accumulators, and may have pipelining benefits.
while (span.Length >= Vector<T>.Count * 4)
{
// Switch accumulators with each step to avoid an additional move operation
Vector<T> data = Vector.LoadUnsafe(ref ptr, index);
Vector<T> data = Vector.Create(span);
Vector<T> accumulator2 = accumulator + data;
Vector<T> overflowTracking = (accumulator2 ^ accumulator) & (accumulator2 ^ data);
overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data);

data = Vector.LoadUnsafe(ref ptr, index + (nuint)Vector<T>.Count);
data = Vector.Create(span.Slice(Vector<T>.Count));
accumulator = accumulator2 + data;
overflowTracking |= (accumulator ^ accumulator2) & (accumulator ^ data);

data = Vector.LoadUnsafe(ref ptr, index + (nuint)Vector<T>.Count * 2);
data = Vector.Create(span.Slice(Vector<T>.Count * 2));
accumulator2 = accumulator + data;
overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data);

data = Vector.LoadUnsafe(ref ptr, index + (nuint)Vector<T>.Count * 3);
data = Vector.Create(span.Slice(Vector<T>.Count * 3));
accumulator = accumulator2 + data;
overflowTracking |= (accumulator ^ accumulator2) & (accumulator ^ data);

if ((overflowTracking & overflowTestVector) != Vector<T>.Zero)
{
ThrowHelper.ThrowOverflowException();
}

index += (nuint)Vector<T>.Count * 4;
} while (index < limit);
span = span.Slice(Vector<T>.Count * 4);
}

// Process remaining vectors, if any, without unrolling
limit = length - (nuint)Vector<T>.Count;
if (index < limit)
while (span.Length >= Vector<T>.Count)
{
Vector<T> overflowTracking = Vector<T>.Zero;

do
{
Vector<T> data = Vector.LoadUnsafe(ref ptr, index);
Vector<T> accumulator2 = accumulator + data;
overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data);
accumulator = accumulator2;
Vector<T> data = Vector.Create(span);
Vector<T> accumulator2 = accumulator + data;
overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data);
accumulator = accumulator2;

index += (nuint)Vector<T>.Count;
} while (index < limit);
span = span.Slice(Vector<T>.Count);
}

if ((overflowTracking & overflowTestVector) != Vector<T>.Zero)
{
ThrowHelper.ThrowOverflowException();
}
if ((overflowTracking & overflowTestVector) != Vector<T>.Zero)
{
ThrowHelper.ThrowOverflowException();
}

// Add the elements in the vector horizontally.
Expand All @@ -180,11 +164,9 @@ private static T SumSignedIntegersVectorized<T>(ReadOnlySpan<T> span)
}

// Add any remaining elements
while (index < length)
foreach (T value in span)
{
checked { result += Unsafe.Add(ref ptr, index); }

index++;
checked { result += value; }
}

return result;
Expand Down