diff --git a/src/libraries/System.Linq/src/System/Linq/Sum.cs b/src/libraries/System.Linq/src/System/Linq/Sum.cs index 481997b8f82633..a9187f9ebb40b6 100644 --- a/src/libraries/System.Linq/src/System/Linq/Sum.cs +++ b/src/libraries/System.Linq/src/System/Linq/Sum.cs @@ -5,7 +5,6 @@ using System.Diagnostics; using System.Numerics; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; namespace System.Linq { @@ -84,9 +83,6 @@ private static T SumSignedIntegersVectorized(ReadOnlySpan span) Debug.Assert(Vector.Count > 2); Debug.Assert(Vector.IsHardwareAccelerated); - ref T ptr = ref MemoryMarshal.GetReference(span); - nuint length = (nuint)span.Length; - // Overflow testing for vectors is based on setting the sign bit of the overflowTracking // vector for an element if the following are all true: // - The two elements being summed have the same sign bit. If one element is positive @@ -104,71 +100,59 @@ private static T SumSignedIntegersVectorized(ReadOnlySpan span) // Thus, if we had a sign swap compared to both inputs, then signof(input1) == signof(input2) and // we must have overflowed. // - // By bitwise or-ing the overflowTracking vector for each step we can save cycles by testing - // the sign bits less often. If any iteration has the sign bit set in any element it indicates - // there was an overflow. + // By bitwise or-ing the overflowTracking vector throughout the entire loop and + // only testing it once at the end we save the cost of an in-loop test+branch per + // iteration. If any accumulation across the whole input has the sign bit set in + // any element it indicates there was an overflow. // // Note: The overflow checking in this algorithm is only correct for signed integers. // If support is ever added for unsigned integers then the overflow check should be: // overflowTracking |= (input1 & input2) | Vector.AndNot(input1 | input2, result); Vector accumulator = Vector.Zero; + Vector overflowTracking = Vector.Zero; // Build a test vector with only the sign bit set in each element. Vector overflowTestVector = new(T.MinValue); - // Unroll the loop to sum 4 vectors per iteration. This reduces range check - // and overflow check frequency, allows us to eliminate move operations swapping - // accumulators, and may have pipelining benefits. - nuint index = 0; - nuint limit = length - (nuint)Vector.Count * 4; - do + // Unroll the loop to sum 4 vectors per iteration. This allows us to eliminate + // move operations swapping accumulators, and may have pipelining benefits. + while (span.Length >= Vector.Count * 4) { // Switch accumulators with each step to avoid an additional move operation - Vector data = Vector.LoadUnsafe(ref ptr, index); + Vector data = Vector.Create(span); Vector accumulator2 = accumulator + data; - Vector overflowTracking = (accumulator2 ^ accumulator) & (accumulator2 ^ data); + overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data); - data = Vector.LoadUnsafe(ref ptr, index + (nuint)Vector.Count); + data = Vector.Create(span.Slice(Vector.Count)); accumulator = accumulator2 + data; overflowTracking |= (accumulator ^ accumulator2) & (accumulator ^ data); - data = Vector.LoadUnsafe(ref ptr, index + (nuint)Vector.Count * 2); + data = Vector.Create(span.Slice(Vector.Count * 2)); accumulator2 = accumulator + data; overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data); - data = Vector.LoadUnsafe(ref ptr, index + (nuint)Vector.Count * 3); + data = Vector.Create(span.Slice(Vector.Count * 3)); accumulator = accumulator2 + data; overflowTracking |= (accumulator ^ accumulator2) & (accumulator ^ data); - if ((overflowTracking & overflowTestVector) != Vector.Zero) - { - ThrowHelper.ThrowOverflowException(); - } - - index += (nuint)Vector.Count * 4; - } while (index < limit); + span = span.Slice(Vector.Count * 4); + } // Process remaining vectors, if any, without unrolling - limit = length - (nuint)Vector.Count; - if (index < limit) + while (span.Length >= Vector.Count) { - Vector overflowTracking = Vector.Zero; - - do - { - Vector data = Vector.LoadUnsafe(ref ptr, index); - Vector accumulator2 = accumulator + data; - overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data); - accumulator = accumulator2; + Vector data = Vector.Create(span); + Vector accumulator2 = accumulator + data; + overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data); + accumulator = accumulator2; - index += (nuint)Vector.Count; - } while (index < limit); + span = span.Slice(Vector.Count); + } - if ((overflowTracking & overflowTestVector) != Vector.Zero) - { - ThrowHelper.ThrowOverflowException(); - } + if ((overflowTracking & overflowTestVector) != Vector.Zero) + { + ThrowHelper.ThrowOverflowException(); } // Add the elements in the vector horizontally. @@ -180,11 +164,9 @@ private static T SumSignedIntegersVectorized(ReadOnlySpan span) } // Add any remaining elements - while (index < length) + foreach (T value in span) { - checked { result += Unsafe.Add(ref ptr, index); } - - index++; + checked { result += value; } } return result;