diff --git a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln index 5a1bf792736d94..dc07cb0b831459 100644 --- a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln +++ b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln @@ -1,4 +1,8 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.8.34205.153 +MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{9F20CEA1-2216-4432-BBBD-F01E05D17F23}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "..\Microsoft.Bcl.Numerics\ref\Microsoft.Bcl.Numerics.csproj", "{D311ABE4-10A9-4BB1-89CE-6358C55501A8}" @@ -33,11 +37,11 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{AB415F5A-75E EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{083161E5-6049-4D84-9739-9D7797D7117D}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "tools\gen", "{841A2FA4-A95F-4612-A8B9-AD2EF769BC71}" +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{841A2FA4-A95F-4612-A8B9-AD2EF769BC71}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "tools\src", "{DF0561A1-3AB8-4B51-AFB4-392EE1DD6247}" +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DF0561A1-3AB8-4B51-AFB4-392EE1DD6247}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "tools\ref", "{7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB}" +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B}" EndProject @@ -105,23 +109,27 @@ Global EndGlobalSection GlobalSection(NestedProjects) = preSolution {9F20CEA1-2216-4432-BBBD-F01E05D17F23} = {DE94CA7D-BB10-4865-85A6-6B694631247F} - {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} {D311ABE4-10A9-4BB1-89CE-6358C55501A8} = {6BC42E6D-848C-4533-B715-F116E7DB3610} - {21CB448A-3882-4337-B416-D1A3E0BCFFC5} = {6BC42E6D-848C-4533-B715-F116E7DB3610} {1578185F-C4FA-4866-936B-E62AAEDD03B7} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} + {21CB448A-3882-4337-B416-D1A3E0BCFFC5} = {6BC42E6D-848C-4533-B715-F116E7DB3610} {848DD000-3D22-4A25-A9D9-05AFF857A116} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} + {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} {4588351F-4233-4957-B84C-7F8E22B8888A} = {083161E5-6049-4D84-9739-9D7797D7117D} {DB954E01-898A-4FE2-A3AA-180D041AB08F} = {083161E5-6049-4D84-9739-9D7797D7117D} {04FC0651-B9D0-448A-A28B-11B1D4A897F4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} {683A7D28-CC55-4375-848D-E659075ECEE4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} - {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} {1CBEAEA8-2CA1-4B07-9930-35A785205852} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} {BA7828B1-7953-47A0-AE5A-E22B501C4BD0} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} - {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} {57E57290-3A6A-43F8-8764-D4DC8151F89C} = {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} + {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} + 
{DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {10A5F2C3-5230-4916-9D4D-BBDB94851037} EndGlobalSection + GlobalSection(SharedMSBuildProjectFiles) = preSolution + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{683a7d28-cc55-4375-848d-e659075ecee4}*SharedItemsImports = 5 + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{ba7828b1-7953-47a0-ae5a-e22b501c4bd0}*SharedItemsImports = 5 + EndGlobalSection EndGlobal diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 954f4924d81c63..d77cd743a0713f 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -13,6 +13,20 @@ namespace System.Numerics.Tensors { public static unsafe partial class TensorPrimitives { + /// Defines the threshold, in bytes, at which non-temporal stores will be used. + /// + /// A non-temporal store is one that allows the CPU to bypass the cache when writing to memory. + /// + /// This can be beneficial when working with large amounts of memory where the writes would otherwise + /// cause large amounts of repeated updates and evictions. The hardware optimization manuals recommend + /// the threshold to be roughly half the size of the last level of on-die cache -- that is, if you have approximately + /// 4MB of L3 cache per core, you'd want this to be approx. 1-2MB, depending on if hyperthreading was enabled. + /// + /// However, actually computing the amount of L3 cache per core can be tricky or error prone. Native memcpy + /// algorithms use a constant threshold that is typically around 256KB and we match that here for simplicity. This + /// threshold accounts for most processors in the last 10-15 years that had approx. 1MB L3 per core and support + /// hyperthreading, giving a per core last level cache of approx. 512KB. + /// private const nuint NonTemporalByteThreshold = 256 * 1024; /// @@ -1355,7 +1369,7 @@ static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) Vector128 vector3; Vector128 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. 
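Editor's aside: as a quick sanity check on what the new `NonTemporalByteThreshold` constant in this hunk means for float data, the arithmetic below is illustrative only and is not part of the patch.

    using System;

    // The patch's 256 KiB threshold, expressed in float elements. The guard
    // `remainder > NonTemporalByteThreshold / sizeof(float)` therefore only
    // takes the cache-bypassing path for destinations larger than 65,536 floats.
    const nuint NonTemporalByteThreshold = 256 * 1024;   // same constant as in the patch
    nuint thresholdInFloats = NonTemporalByteThreshold / sizeof(float);
    Console.WriteLine(thresholdInFloats);                 // 65536

That lines up with the doc comment's "roughly half the last-level cache per core" reasoning: 256 KiB of floats is well past the point where temporal stores just churn the cache.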
@@ -1505,10 +1519,10 @@ static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -1537,12 +1551,11 @@ static void Vectorized128Small(ref float xRef, ref float dRef, nuint remainder) case 1: { dRef = TUnaryOperator.Invoke(xRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -1597,7 +1610,7 @@ static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) Vector256 vector3; Vector256 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -1747,10 +1760,10 @@ static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -1804,12 +1817,11 @@ static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder) case 1: { dRef = TUnaryOperator.Invoke(xRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -1865,7 +1877,7 @@ static void Vectorized512(ref float xRef, ref float dRef, nuint remainder) Vector512 vector3; Vector512 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. 
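Editor's aside: several hunks above and below swap `default:` blocks for explicit `case 0:` labels and `goto default` for `goto case 0`. Since C# switch sections never fall through implicitly, the chaining has to be spelled out with `goto case`. The standalone sketch below only illustrates that language mechanic; `HandleRemainder` and its bodies are made up for the example and are not code from the patch.

    using System;

    // Minimal sketch of the explicit case-chaining pattern used in the *Small helpers:
    // each case does its element, then jumps to the next-lower case.
    static void HandleRemainder(int remainder)
    {
        switch (remainder)
        {
            case 2:
                Console.WriteLine("handle element 1");
                goto case 1;
            case 1:
                Console.WriteLine("handle element 0");
                goto case 0;
            case 0:
                break; // nothing left to do
        }
    }

    HandleRemainder(2); // prints both lines

Naming the terminal label `case 0` rather than `default` makes the jump target match the remainder value it handles, which is presumably why the patch makes the swap; the control flow itself is unchanged.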
@@ -2015,10 +2027,10 @@ static void Vectorized512(ref float xRef, ref float dRef, nuint remainder) { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -2101,12 +2113,11 @@ static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder) case 1: { dRef = TUnaryOperator.Invoke(xRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -2138,684 +2149,5399 @@ private static void InvokeSpanSpanIntoSpan( ValidateInputOutputSpanNonOverlapping(x, destination); ValidateInputOutputSpanNonOverlapping(y, destination); + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); #if NET8_0_OR_GREATER if (Vector512.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector512.Count)) { - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); - } + Vectorized512(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + Vectorized512Small(ref xRef, ref yRef, ref dRef, remainder); } + + return; } #endif if (Vector256.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector256.Count)) { - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. 
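Editor's aside: the loops deleted below handled the final partial vector by blending a freshly computed full vector with the destination's existing contents, using a remainder mask from the file-internal `CreateRemainderMaskSingleVector128/256/512` helpers. The new code drops that in favor of precomputing `beg`/`end` vectors and allowing the stores to overlap. The snippet below is only a sketch of the old blend idea with a hard-coded "3 of 4 lanes valid" mask; it is not the patch's code.

    using System;
    using System.Runtime.Intrinsics;

    // Mask lanes: all-bits-set where the element is in range, zero where it is not.
    Vector128<float> mask = Vector128.Create(-1, -1, -1, 0).AsSingle();

    Vector128<float> existing = Vector128.Create(9f, 9f, 9f, 9f);  // what's already in dest
    Vector128<float> computed = Vector128.Create(1f, 2f, 3f, 4f);  // freshly computed values

    Vector128<float> blended = Vector128.ConditionalSelect(
        Vector128.Equals(mask, Vector128<float>.Zero),
        existing,   // keep the destination where the mask lane is zero
        computed);  // take the new value where the mask lane is set

    Console.WriteLine(blended); // <1, 2, 3, 9>

The blend keeps out-of-range lanes untouched at the cost of an extra load, compare, and select on every tail; the new overlapping-store scheme avoids all of that.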
- if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); - } + Vectorized256(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + Vectorized256Small(ref xRef, ref yRef, ref dRef, remainder); } + + return; } if (Vector128.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector128.Count)) { - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); - } - - return; + Vectorized128(ref xRef, ref yRef, ref dRef, remainder); } - } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i)); + Vectorized128Small(ref xRef, ref yRef, ref dRef, remainder); + } - i++; + return; } - } - /// - /// Performs an element-wise operation on and , - /// and writes the results to . - /// - /// - /// Specifies the operation to perform on each element loaded from with . - /// - private static void InvokeSpanScalarIntoSpan( - ReadOnlySpan x, float y, Span destination) - where TBinaryOperator : struct, IBinaryOperator => - InvokeSpanScalarIntoSpan(x, y, destination); + // This is the software fallback when no acceleration is available + // It requires no branches to hit - /// - /// Performs an element-wise operation on and , - /// and writes the results to . - /// - /// - /// Specifies the operation to perform on each element loaded from . - /// It is not used with . - /// - /// - /// Specifies the operation to perform on the transformed value from with . 
- /// - private static void InvokeSpanScalarIntoSpan( - ReadOnlySpan x, float y, Span destination) - where TTransformOperator : struct, IUnaryOperator - where TBinaryOperator : struct, IBinaryOperator - { - if (x.Length > destination.Length) + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); + } } - ValidateInputOutputSpanNonOverlapping(x, destination); + static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + // Preload the beginning and end so that overlapping accesses don't negatively impact the data -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) { - Vector512 yVec = Vector512.Create(y); + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - // Loop handling one vector at a time. - do + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) { - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i)), - yVec).StoreUnsafe(ref dRef, (uint)i); + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex)), - yVec)).StoreUnsafe(ref dRef, lastVectorIndex); - } + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; - return; - } - } -#endif + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - Vector256 yVec = Vector256.Create(y); + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); - // Loop handling one vector at a time. 
- do - { - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i)), - yVec).StoreUnsafe(ref dRef, (uint)i); + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex)), - yVec)).StoreUnsafe(ref dRef, lastVectorIndex); - } + remainder -= misalignment; + } - return; - } - } + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; - if (Vector128.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) - { - Vector128 yVec = Vector128.Create(y); + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i)), - yVec).StoreUnsafe(ref dRef, (uint)i); + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex)), - yVec)).StoreUnsafe(ref dRef, lastVectorIndex); - } + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); - return; - } - } + // We load, process, and store the next four vectors - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, i)), - y); + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); - i++; - } - } + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); - /// - /// Performs an element-wise operation on , , and , - /// and writes the results to . - /// - /// - /// Specifies the operation to perform on the pair-wise elements loaded from , , - /// and . - /// - private static void InvokeSpanSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) - where TTernaryOperator : struct, ITernaryOperator - { - if (x.Length != y.Length || x.Length != z.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
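Editor's aside: the pointer bumps that follow advance by `Count * 8` elements per iteration of these unrolled loops. The arithmetic below is just a sanity check of what that stride works out to per vector width (Vector512 requires .NET 8, matching the patch's `#if`); it is not code from the patch.

    using System;
    using System.Runtime.Intrinsics;

    // Elements and bytes consumed per iteration of the 8-vector unrolled loops.
    Console.WriteLine($"V128: {Vector128<float>.Count * 8} floats = {Vector128<float>.Count * 8 * sizeof(float)} bytes"); // 32 floats = 128 bytes
    Console.WriteLine($"V256: {Vector256<float>.Count * 8} floats = {Vector256<float>.Count * 8 * sizeof(float)} bytes"); // 64 floats = 256 bytes
    Console.WriteLine($"V512: {Vector512<float>.Count * 8} floats = {Vector512<float>.Count * 8 * sizeof(float)} bytes"); // 128 floats = 512 bytes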
- if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - ValidateInputOutputSpanNonOverlapping(x, destination); - ValidateInputOutputSpanNonOverlapping(y, destination); - ValidateInputOutputSpanNonOverlapping(z, destination); + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
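Editor's aside: to make the "reprocessing" comment above concrete, here is the dispatch arithmetic worked through for the 128-bit path with illustrative numbers (13 elements, so the bulk 8-vector loop is skipped and the jump table handles everything). This is a reading of the code, not code from the patch.

    using System;
    using System.Runtime.Intrinsics;

    int count = Vector128<float>.Count;                                   // 4
    nuint endIndex = 13;                                                  // elements still to process
    nuint remainder = (endIndex + (uint)(count - 1)) & (nuint)(-count);   // rounds 13 up to 16

    Console.WriteLine(remainder / (uint)count); // 4 -> the switch enters at case 4

    // Relative to the (possibly advanced) refs: case 4 writes elements [0, 4),
    // case 3 writes [4, 8), case 2 writes [8, 12), and case 1 stores the preloaded
    // 'end' vector at endIndex - Count == 9, i.e. [9, 13) -- re-writing 9..11.
    // Case 0 then stores 'beg' over [0, 4) of the original destination, which here
    // overlaps case 4. That is the reprocessing the comment describes; every store
    // is still a full, in-bounds vector.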
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) { - // Loop handling one vector at a time. - do + case 8: { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } - i += Vector512.Count; + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 6: { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - Vector512.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; } - return; + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } } } -#endif - if (Vector256.IsHardwareAccelerated) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void 
Vectorized128Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + switch (remainder) { - // Loop handling one vector at a time. - do + case 3: { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } - i += Vector256.Count; + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 1: { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - Vector256.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; } - return; + case 0: + { + break; + } } } - if (Vector128.IsHardwareAccelerated) + static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint remainder) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) { - // Loop handling one vector at a time. - do + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
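Editor's aside: the misalignment arithmetic referenced above is easier to follow with a concrete (hypothetical) address; the values below are illustrative only, with 32 standing in for `sizeof(Vector256<float>)`.

    using System;

    nuint dAddr = 0x1234_5614;   // hypothetical destination address
    nuint vecBytes = 32;         // size of a Vector256<float> in bytes
    nuint misalignment = (vecBytes - (dAddr % vecBytes)) / sizeof(float);

    Console.WriteLine(dAddr % vecBytes);  // 20 -> 20 bytes into a 32-byte block
    Console.WriteLine(misalignment);      // (32 - 20) / 4 = 3 floats to skip

Advancing all three pointers by those 3 elements puts `dPtr` on a 32-byte boundary. If `dPtr` happened to be aligned already, the formula yields a full vector (8 elements), and those elements are still written because `beg` is stored unconditionally in case 0 of the jump table.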
+ + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs - i += Vector128.Count; + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) + case 7: { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - Vector128.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; } - return; - } - } + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref 
dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
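Editor's aside: for readers unfamiliar with the two store flavors used in these loops, the sketch below contrasts them in isolation. The buffer and values are made up; it needs `<AllowUnsafeBlocks>` and .NET 8 (for `Vector512`), and the real code additionally gates the non-temporal path on `canAlign` and the 256 KiB threshold.

    using System;
    using System.Runtime.InteropServices;
    using System.Runtime.Intrinsics;

    unsafe
    {
        // 64-byte alignment satisfies Vector512's aligned-store requirement.
        float* d = (float*)NativeMemory.AlignedAlloc(64 * sizeof(float), 64);

        Vector512<float> v = Vector512.Create(1.0f);

        v.Store(d);                                             // plain store: no alignment requirement, goes through the cache
        v.StoreAlignedNonTemporal(d + Vector512<float>.Count);  // aligned store that hints the CPU to bypass the cache

        Console.WriteLine(d[0] + d[Vector512<float>.Count]);    // 2

        NativeMemory.AlignedFree(d);
    }

The alignment requirement is why the patch only enters the non-temporal loop when `canAlign` is true and the destination pointer has been adjusted onto a vector boundary.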
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
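Editor's aside: one small note on the rounding expression that follows this comment. Because `Vector512<float>.Count` (16) is a power of two, `& (nuint)(-Count)` is simply a mask that clears the low four bits, so the `+ (Count - 1)` followed by the mask rounds up to the next multiple of 16. A quick check, not from the patch:

    using System;

    int count = 16; // Vector512<float>.Count
    Console.WriteLine((nuint)(-count) == ~(nuint)(count - 1)); // True

    for (nuint r = 16; r <= 20; r++)
    {
        nuint rounded = (r + (uint)(count - 1)) & (nuint)(-count);
        Console.WriteLine($"{r} -> {rounded}"); // 16 -> 16, 17..20 -> 32
    }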
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } 
+ + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from with . + /// + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TBinaryOperator : struct, IBinaryOperator => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from . + /// It is not used with . + /// + /// + /// Specifies the operation to perform on the transformed value from with . + /// + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TTransformOperator : struct, IUnaryOperator + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, y, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
+ + Vectorized256Small(ref xRef, y, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, y, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, i)), + y); + } + } + + static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
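Editor's aside: the span-with-scalar overloads above follow the same structure as the span-span ones; the only new ingredient is broadcasting `y` once into `yVec` so the inner loops stay branch-free. A tiny standalone illustration follows; multiplication is used only as a concrete stand-in for the patch's `TTransformOperator`/`TBinaryOperator` pair.

    using System;
    using System.Runtime.Intrinsics;

    // Broadcast the scalar once, then reuse it against every loaded vector.
    float y = 3f;
    Vector128<float> yVec = Vector128.Create(y);           // <3, 3, 3, 3>

    Vector128<float> x = Vector128.Create(1f, 2f, 3f, 4f);
    Vector128<float> result = x * yVec;                    // element-wise multiply

    Console.WriteLine(result);                             // <3, 6, 9, 12>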
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
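
// --- Illustrative aside (editorial sketch, not part of this diff) ---
// Worked example of the rounding/jump-table math used below, assuming
// Vector128<float>.Count == 4 and 23 unprocessed elements (endIndex == 23):
//
//   remainder = (23 + 3) & ~3 = 24, so dispatch starts at case 24 / 4 = 6.
//   Cases 6..2 store full vectors at element offsets 0, 4, 8, 12, and 16.
//   Case 1 stores the preloaded 'end' vector at endIndex - 4 = 19, re-writing
//   element 19 with the same value (harmless, because 'end' was computed from the
//   untouched input up front).
//   Case 0 stores the preloaded 'beg' vector at the original destination start,
//   covering any elements skipped when the pointers were bumped to alignment.
//
// The rounding itself, as a hypothetical standalone helper:
using System.Runtime.Intrinsics;

internal static class RoundingSketch
{
    public static nuint RoundUpToFullVectors(nuint remainder) =>
        (remainder + (nuint)(Vector128<float>.Count - 1)) & (nuint)(-Vector128<float>.Count);
}
// --- end aside ---
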
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be 
short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
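
// --- Illustrative aside (editorial sketch, not part of this diff) ---
// The 'misalignment' computation above, in isolation. With sizeof(Vector256<float>) == 32
// and a destination at (address % 32) == 24, this yields (32 - 24) / 4 = 2 elements, so
// both pointers are bumped by 2 and dPtr becomes 32-byte aligned. If dPtr is already
// aligned, the expression yields a full vector's worth (8 elements); that appears safe
// because the preloaded 'beg' vector is stored over the skipped region in case 0 of the
// jump table. Helper and class names are hypothetical.
using System.Runtime.Intrinsics;

internal static unsafe class AlignmentSketch
{
    public static nuint ElementsUntilAligned(float* dPtr)
    {
        nuint vectorBytes = (uint)sizeof(Vector256<float>);
        return (vectorBytes - ((nuint)dPtr % vectorBytes)) / sizeof(float);
    }
}
// --- end aside ---
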
+ + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef 
= TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 yVec = Vector512.Create(y); + + Vector512 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)), + yVec); + Vector512 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))), + yVec); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
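
// --- Illustrative aside (editorial sketch, not part of this diff) ---
// The pin-then-rebind pattern used above ('xRef = ref *xPtr'), reduced to its skeleton:
// pin only while the bulk loop runs, then turn the advanced pointer back into a managed
// ref so the tail code is shared with the never-pinned small path. All names, the
// threshold, and the placeholder work are hypothetical.
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

internal static unsafe class PinSketch
{
    public static void Process(Span<float> destination)
    {
        ref float dRef = ref MemoryMarshal.GetReference(destination);
        nuint remainder = (uint)destination.Length;

        if (remainder > 64)      // only pin when the bulk path is worth it
        {
            fixed (float* pd = &dRef)
            {
                float* dPtr = pd;

                while (remainder >= 8)
                {
                    // ... bulk vectorized work on dPtr would go here ...
                    dPtr += 8;
                    remainder -= 8;
                }

                // Rebind the ref to wherever the pointer ended up; once we leave the
                // fixed block nothing stays pinned, yet dRef now addresses the tail.
                dRef = ref *dPtr;
            }
        }

        for (nuint i = 0; i < remainder; i++)
        {
            Unsafe.Add(ref dRef, i) = 0f;   // placeholder tail work
        }
    }
}
// --- end aside ---
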
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + Vector256.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = 
TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from , , + /// and . + /// + private static void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. 
So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
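
// --- Illustrative aside (editorial sketch, not part of this diff) ---
// The TBinaryOperator/TTernaryOperator calls above rely on generic struct "operator"
// types so the JIT specializes and inlines each Invoke call site per instantiation.
// The interface shape below is an assumption for illustration only (the real
// IBinaryOperator/ITernaryOperator definitions are elsewhere in this file);
// AddOperatorSketch is a hypothetical example operator.
using System.Runtime.Intrinsics;

internal interface IBinaryOperatorSketch
{
    static abstract float Invoke(float x, float y);
    static abstract Vector128<float> Invoke(Vector128<float> x, Vector128<float> y);
}

internal readonly struct AddOperatorSketch : IBinaryOperatorSketch
{
    public static float Invoke(float x, float y) => x + y;
    public static Vector128<float> Invoke(Vector128<float> x, Vector128<float> y) => x + y;
}

internal static class OperatorSketch
{
    // Generic over the operator struct: TOp.Invoke compiles down to the concrete,
    // trivially inlineable implementation for each instantiation.
    public static float Apply<TOp>(float x, float y)
        where TOp : struct, IBinaryOperatorSketch
        => TOp.Invoke(x, y);

    // Usage: float sum = OperatorSketch.Apply<AddOperatorSketch>(2f, 3f);  // 5
}
// --- end aside ---
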
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 
1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
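
// --- Illustrative aside (editorial sketch, not part of this diff) ---
// The "preload beg/end" trick in isolation: compute the first and last vectors from the
// untouched input up front, process full vectors in between, then store beg and end
// unconditionally. Partial blocks at either edge are covered by overlapping full-vector
// stores, so no scalar tail loop is needed once length >= Vector128<float>.Count (a
// precondition assumed here). Scale2 and the x * 2 operation are hypothetical stand-ins.
using System;
using System.Runtime.Intrinsics;
using System.Runtime.InteropServices;

internal static class OverlapSketch
{
    public static void Scale2(ReadOnlySpan<float> x, Span<float> d)
    {
        ref float xRef = ref MemoryMarshal.GetReference(x);
        ref float dRef = ref MemoryMarshal.GetReference(d);
        nuint length = (uint)x.Length;
        nuint count = (uint)Vector128<float>.Count;

        // Both edge vectors are computed before any store happens, which is what keeps
        // the overlapping writes below correct even if x and d refer to the same memory.
        Vector128<float> beg = Vector128.LoadUnsafe(ref xRef) * 2f;
        Vector128<float> end = Vector128.LoadUnsafe(ref xRef, length - count) * 2f;

        // Full vectors; the last one may stop short of the ragged tail.
        for (nuint i = count; i + count <= length; i += count)
        {
            (Vector128.LoadUnsafe(ref xRef, i) * 2f).StoreUnsafe(ref dRef, i);
        }

        // Unconditional edge stores cover the first block and the possibly partial,
        // possibly overlapping last block.
        beg.StoreUnsafe(ref dRef);
        end.StoreUnsafe(ref dRef, length - count);
    }
}
// --- end aside ---
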
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + 
Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
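The pointer adjustments that follow close out one iteration of the unrolled loop above: eight vectors are computed and stored per pass, in two groups of four, and only then are the cursors advanced. Shrunk to two vectors per pass and a stand-in "multiply by two" transform, the same skeleton looks like this (names are illustrative, not from the patch); unrolling by eight presumably just keeps more independent loads and stores in flight than the two shown here.

    using System.Runtime.Intrinsics;

    internal static unsafe class UnrolledLoopSketch
    {
        public static void DoubleAll(float* src, float* dst, nuint remainder)
        {
            while (remainder >= (uint)(Vector256<float>.Count * 2))
            {
                // Compute a pair of blocks...
                Vector256<float> v0 = Vector256.Load(src + (Vector256<float>.Count * 0)) * 2f;
                Vector256<float> v1 = Vector256.Load(src + (Vector256<float>.Count * 1)) * 2f;

                // ...store them...
                v0.Store(dst + (Vector256<float>.Count * 0));
                v1.Store(dst + (Vector256<float>.Count * 1));

                // ...then advance the cursors and shrink the remaining count,
                // exactly as the Count * 8 loop above does.
                src += Vector256<float>.Count * 2;
                dst += Vector256<float>.Count * 2;
                remainder -= (uint)(Vector256<float>.Count * 2);
            }

            // Plain scalar tail; the real code hands leftovers to a jump table instead.
            for (nuint i = 0; i < remainder; i++)
            {
                dst[i] = src[i] * 2f;
            }
        }
    }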
+ + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
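Concretely, the two lines that follow turn the leftover element count into a switch entry point: `remainder` is rounded up to the next multiple of `Count` (the `& (nuint)(-Count)` mask is the same as `~(Count - 1)` for a power-of-two count), the quotient picks the case, and each case stores one full vector at `remainder - Count * N`, falling through until the preloaded `end` and `beg` vectors close out the tail and the head. A worked example, assuming `Vector256<float>` so `Count == 8`:

    // endIndex  = 19                          // elements still unprocessed
    // remainder = (19 + 7) & ~7 = 24          // rounded up to a multiple of 8
    // entry     = 24 / 8       = 3            // the switch starts at case 3
    //
    // case 3 stores a full vector at offset 24 - 24 = 0
    // case 2 stores a full vector at offset 24 - 16 = 8
    // case 1 stores `end` at endIndex - 8 = 11   (covers elements 11..18)
    // case 0 stores `beg` at offset 0            (re-covers elements 0..7)
    //
    // Every element in [0, 19) is written; 0..7 and 11..15 are written twice
    // with identical values, which is cheaper than running a scalar tail loop.
    static nuint RoundUpToCount(nuint remainder, int count) =>
        (remainder + (nuint)(count - 1)) & ~(nuint)(count - 1);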
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), 
+ Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and + /// with . + /// + private static void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
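Every `InvokeSpan...IntoSpan` entry point in this file follows the same selection shape: take the widest accelerated vector path the input can fill, route sub-vector lengths to a `Vectorized*Small` jump table, and reach the scalar loop only when no SIMD is available. Stripped of the per-width helpers and with a stand-in unary negate, the cascade is roughly the sketch below (illustrative names; the Vector512 tier, gated by NET8_0_OR_GREATER above, is omitted, and the real small paths use jump tables and overlapping vectors rather than the scalar tails shown here).

    using System;
    using System.Runtime.CompilerServices;
    using System.Runtime.InteropServices;
    using System.Runtime.Intrinsics;

    internal static class DispatchSketch
    {
        public static void Negate(ReadOnlySpan<float> x, Span<float> destination)
        {
            if (x.Length > destination.Length)
            {
                throw new ArgumentException("Destination is too short.", nameof(destination));
            }

            ref float xRef = ref MemoryMarshal.GetReference(x);
            ref float dRef = ref MemoryMarshal.GetReference(destination);
            nuint remainder = (uint)x.Length;

            if (Vector256.IsHardwareAccelerated && remainder >= (uint)Vector256<float>.Count)
            {
                nuint i = 0;
                for (; i + (uint)Vector256<float>.Count <= remainder; i += (uint)Vector256<float>.Count)
                {
                    (-Vector256.LoadUnsafe(ref xRef, i)).StoreUnsafe(ref dRef, i);
                }
                for (; i < remainder; i++)
                {
                    Unsafe.Add(ref dRef, i) = -Unsafe.Add(ref xRef, i);
                }
                return;
            }

            if (Vector128.IsHardwareAccelerated && remainder >= (uint)Vector128<float>.Count)
            {
                nuint i = 0;
                for (; i + (uint)Vector128<float>.Count <= remainder; i += (uint)Vector128<float>.Count)
                {
                    (-Vector128.LoadUnsafe(ref xRef, i)).StoreUnsafe(ref dRef, i);
                }
                for (; i < remainder; i++)
                {
                    Unsafe.Add(ref dRef, i) = -Unsafe.Add(ref xRef, i);
                }
                return;
            }

            // No SIMD at all: plain scalar loop, mirroring the SoftwareFallback
            // local functions in this file.
            for (nuint i = 0; i < remainder; i++)
            {
                Unsafe.Add(ref dRef, i) = -Unsafe.Add(ref xRef, i);
            }
        }
    }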
+ + Vectorized512Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
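With concrete numbers, the computation that follows skips just enough leading elements for the destination to land on a vector boundary. Assuming `Vector128` (16-byte blocks) and a destination address ending in 0x8: 16 - 8 = 8 bytes to the next boundary, and 8 / sizeof(float) = 2 elements to skip on every operand so they stay in lockstep. If the pointer happens to be aligned already, the formula yields a full vector's worth of elements, which is harmless because the preloaded `beg` block re-covers them when the jump table runs case 0. As a tiny helper (illustrative name, not part of the patch):

    internal static unsafe class AlignmentSketch
    {
        // Number of float elements to skip so that dPtr becomes
        // vectorSizeInBytes-aligned (vectorSizeInBytes must be a power of two).
        public static nuint ElementsUntilAligned(float* dPtr, uint vectorSizeInBytes) =>
            (vectorSizeInBytes - ((nuint)dPtr % vectorSizeInBytes)) / sizeof(float);
    }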
+ + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
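The "reprocessing" this comment describes is the central trick of these kernels: the first and last blocks were computed into `beg` and `end` before any store happened, so the tail can be finished with one overlapping vector store instead of a scalar loop, and the head can be re-stored even though alignment skipping may have advanced past it. Reading both edges before writing also appears to be what keeps the in-place case (destination aliasing an input exactly) correct. In isolation, with `Vector128` (Count == 4), 10 elements, and a stand-in "add one" transform (names are illustrative):

    using System.Diagnostics;
    using System.Runtime.InteropServices;
    using System.Runtime.Intrinsics;

    internal static class OverlapSketch
    {
        public static void AddOne(ReadOnlySpan<float> x, Span<float> d)
        {
            Debug.Assert(x.Length >= Vector128<float>.Count && x.Length <= d.Length);

            ref float xRef = ref MemoryMarshal.GetReference(x);
            ref float dRef = ref MemoryMarshal.GetReference(d);
            nuint length = (uint)x.Length;

            // Read both edge blocks before anything is written.
            Vector128<float> beg = Vector128.LoadUnsafe(ref xRef) + Vector128<float>.One;
            Vector128<float> end = Vector128.LoadUnsafe(ref xRef, length - (uint)Vector128<float>.Count) + Vector128<float>.One;

            // Bulk: full blocks only. For 10 elements this writes [4, 8).
            for (nuint i = (uint)Vector128<float>.Count; i + (uint)Vector128<float>.Count <= length; i += (uint)Vector128<float>.Count)
            {
                (Vector128.LoadUnsafe(ref xRef, i) + Vector128<float>.One).StoreUnsafe(ref dRef, i);
            }

            // Edges last: `end` writes [6, 10) (elements 6 and 7 a second time,
            // with identical values) and `beg` writes [0, 4).
            end.StoreUnsafe(ref dRef, length - (uint)Vector128<float>.Count);
            beg.StoreUnsafe(ref dRef);
        }
    }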
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 zVec = 
Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
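Note that the scalar operand was splatted once, up front, via `Vector256.Create(z)`, and the resulting `zVec` is simply reused by every block in the loops below; nothing is re-broadcast per iteration. One plausible ternary operation behind this overload is a multiply-add with a scalar addend, which a caller would presumably reach through a public entry point such as `TensorPrimitives.MultiplyAdd(x, y, z, destination)`; that mapping is an assumption here, not something this hunk shows. A reduced sketch, assuming the length is an exact multiple of the vector width and the destination is long enough:

    using System.Runtime.InteropServices;
    using System.Runtime.Intrinsics;

    internal static class BroadcastSketch
    {
        // destination[i] = x[i] * y[i] + z, with the scalar broadcast exactly once.
        public static void MultiplyAddScalar(ReadOnlySpan<float> x, ReadOnlySpan<float> y, float z, Span<float> destination)
        {
            ref float xRef = ref MemoryMarshal.GetReference(x);
            ref float yRef = ref MemoryMarshal.GetReference(y);
            ref float dRef = ref MemoryMarshal.GetReference(destination);

            Vector256<float> zVec = Vector256.Create(z);   // splat the scalar once

            for (nuint i = 0; i < (uint)x.Length; i += (uint)Vector256<float>.Count)
            {
                Vector256<float> result = (Vector256.LoadUnsafe(ref xRef, i) * Vector256.LoadUnsafe(ref yRef, i)) + zVec;
                result.StoreUnsafe(ref dRef, i);
            }
        }
    }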
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
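One detail that is easy to miss just above this jump table: after the pinned bulk loops finish, `xRef = ref *xPtr;`, `yRef = ref *yPtr;` and `dRef = ref *dPtr;` re-point the incoming `ref` parameters at the first unprocessed element (C# ref reassignment), so the switch below indexes from zero and never needs pointers or a `fixed` block of its own. The same move in isolation, as a hypothetical helper that is not part of the patch:

    internal static unsafe class RefReseatSketch
    {
        // Advance two refs past an already-processed prefix by going through
        // pointers, handing them back re-seated at the first unprocessed element.
        public static void SkipProcessedPrefix(ref float xRef, ref float dRef, nuint processed)
        {
            fixed (float* px = &xRef)
            fixed (float* pd = &dRef)
            {
                xRef = ref *(px + processed);
                dRef = ref *(pd + processed);
            }
        }
    }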
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref 
xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 zVec = Vector512.Create(z); + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + zVec); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), + zVec); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 zVec = Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = 
TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise element loaded from , with , + /// and the element loaded from . + /// + private static void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
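Throughout this file the operation itself never appears; every kernel is generic over a `TTernaryOperator` constrained to `struct, ITernaryOperator`, and the call sites invoke it statically (`TTernaryOperator.Invoke(...)`) on scalars and on each vector width. The interface definition is not part of this hunk, but the call shape implies something like the sketch below; on .NET 7+ this is naturally expressed with static abstract interface members, and the `Sketch` names plus the multiply-add operator are illustrative only. Because the operator is a specialized value type with statically dispatched members, the JIT can inline every `Invoke` into the loops above, which is why one vectorized skeleton per overload shape serves the whole operation family.

    using System.Runtime.Intrinsics;

    internal interface ITernaryOperatorSketch
    {
        static abstract float Invoke(float x, float y, float z);
        static abstract Vector128<float> Invoke(Vector128<float> x, Vector128<float> y, Vector128<float> z);
        static abstract Vector256<float> Invoke(Vector256<float> x, Vector256<float> y, Vector256<float> z);
        // A Vector512 overload would sit behind NET8_0_OR_GREATER, matching the kernels.
    }

    // Example operator: x * y + z.
    internal readonly struct MultiplyAddOperatorSketch : ITernaryOperatorSketch
    {
        public static float Invoke(float x, float y, float z) => (x * y) + z;
        public static Vector128<float> Invoke(Vector128<float> x, Vector128<float> y, Vector128<float> z) => (x * y) + z;
        public static Vector256<float> Invoke(Vector256<float> x, Vector256<float> y, Vector256<float> z) => (x * y) + z;
    }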
+ + Vectorized256Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 yVec = 
Vector256.Create(y); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + yVec, + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
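The canAlign/misalignment adjustment near the top of this routine is easiest to see with concrete numbers. The following standalone sketch is illustrative only; the address is made up, and the sizes assume sizeof(Vector256<float>) == 32 and sizeof(float) == 4.

using System;

class AlignmentSketch
{
    static void Main()
    {
        const uint vectorBytes = 32;     // sizeof(Vector256<float>)
        const uint elementBytes = 4;     // sizeof(float)

        nuint dAddr = 0x1014;            // hypothetical destination address; 0x1014 % 32 == 20

        // Same computation as above: how many whole floats to skip so the destination becomes 32-byte aligned.
        nuint misalignment = (vectorBytes - (dAddr % vectorBytes)) / elementBytes;

        Console.WriteLine(misalignment);                                          // (32 - 20) / 4 = 3
        Console.WriteLine((dAddr + misalignment * elementBytes) % vectorBytes);   // 0
    }
}

Note that an already 32-byte-aligned destination yields a misalignment of a full vector's worth of elements (8) rather than 0; that is harmless because the preloaded beg block is stored over those first elements by case 0 of the jump table that follows.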
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref 
xRef), + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 yVec = Vector512.Create(y); + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + yVec, + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
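For float data the threshold is easy to put in concrete terms. A small illustrative calculation (top-level statements, not part of the patch):

const nuint NonTemporalByteThreshold = 256 * 1024;                    // the constant added at the top of this file
nuint thresholdInFloats = NonTemporalByteThreshold / sizeof(float);
System.Console.WriteLine(thresholdInFloats);                          // 65536 elements, i.e. 256 KB of floats

So a remainder larger than roughly 64K floats, combined with a destination that could be aligned, is what routes execution into the StoreAlignedNonTemporal loop; smaller workloads fall through to the ordinary cached stores in the else branch further down.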
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); - i++; - } - } + // We load, process, and store the next four vectors - /// - /// Performs an element-wise operation on , , and , - /// and writes the results to . 
- /// - /// - /// Specifies the operation to perform on the pair-wise elements loaded from and - /// with . - /// - private static void InvokeSpanSpanScalarIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, float z, Span destination) - where TTernaryOperator : struct, ITernaryOperator - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); - ValidateInputOutputSpanNonOverlapping(x, destination); - ValidateInputOutputSpanNonOverlapping(y, destination); + // We adjust the source and destination references, then update + // the count of remaining elements to process. - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - Vector512 zVec = Vector512.Create(z); + remainder -= (uint)(Vector512.Count * 8); + } + } - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + // Adjusting the refs here allows us to avoid pinning for very small inputs - i += Vector512.Count; + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - zVec)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
- return; - } - } -#endif + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + switch (remainder / (uint)(Vector512.Count)) { - Vector256 zVec = Vector256.Create(z); - - // Loop handling one vector at a time. - do + case 8: { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } - i += Vector256.Count; + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 6: { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - zVec)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; } - return; - } - } + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } - if (Vector128.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) - { - Vector128 zVec = Vector128.Create(z); + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } - // Loop handling one vector at a time. 
- do + case 3: { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } - i += Vector128.Count; + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 1: { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - zVec)).StoreUnsafe(ref dRef, lastVectorIndex); + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; } - return; + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } } } - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); - - i++; - } - } - - /// - /// Performs an element-wise operation on , , and , - /// and writes the results to . - /// - /// - /// Specifies the operation to perform on the pair-wise element loaded from , with , - /// and the element loaded from . 
- /// - private static void InvokeSpanScalarSpanIntoSpan( - ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) - where TTernaryOperator : struct, ITernaryOperator - { - if (x.Length != z.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + Vector256 yVec = Vector256.Create(y); - ValidateInputOutputSpanNonOverlapping(x, destination); - ValidateInputOutputSpanNonOverlapping(z, destination); + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + yVec, + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - Vector512 yVec = Vector512.Create(y); + break; + } - // Loop handling one vector at a time. - do + case 8: { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.Create(y), + Vector256.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); - i += Vector512.Count; + break; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 7: + case 6: + case 5: { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector512.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); - } + Debug.Assert(Vector128.IsHardwareAccelerated); - return; - } - } -#endif + Vector128 yVec = Vector128.Create(y); - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - Vector256 yVec = Vector256.Create(y); + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); - // Loop handling one vector at a time. 
- do - { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); - i += Vector256.Count; + break; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 4: { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector256.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); - } + Debug.Assert(Vector128.IsHardwareAccelerated); - return; - } - } + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); - if (Vector128.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) - { - Vector128 yVec = Vector128.Create(y); + break; + } - // Loop handling one vector at a time. - do + case 3: { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } - i += Vector128.Count; + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 1: { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector128.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; } - return; + case 0: + { + break; + } } } - - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); - - i++; - } +#endif } /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. 
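Taken together, the rewritten helpers in this file share one dispatch shape: a single acceleration check, a single length comparison, and then either a bulk vectorized path, a small jump-table path for inputs shorter than one vector, or a scalar software fallback. The condensed standalone sketch below illustrates that shape; it is not the library code. It uses a simple unary negation and only Vector128, omits the alignment, 8x unrolling, and non-temporal handling shown above, and assumes the input and destination do not overlap (unlike the real helpers, it does not preload the tail, so fully in-place use is not handled here).

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

static class DispatchSketch
{
    public static void Negate(ReadOnlySpan<float> x, Span<float> destination)
    {
        if (x.Length > destination.Length)
        {
            throw new ArgumentException("Destination is too short.", nameof(destination));
        }

        ref float xRef = ref MemoryMarshal.GetReference(x);
        ref float dRef = ref MemoryMarshal.GetReference(destination);
        nuint remainder = (uint)x.Length;

        if (Vector128.IsHardwareAccelerated)
        {
            if (remainder >= (uint)Vector128<float>.Count)
            {
                Vectorized(ref xRef, ref dRef, remainder);      // bulk path after one length check
            }
            else
            {
                VectorizedSmall(ref xRef, ref dRef, remainder); // fewer elements than one vector
            }

            return;
        }

        SoftwareFallback(ref xRef, ref dRef, remainder);        // no acceleration available
    }

    static void Vectorized(ref float xRef, ref float dRef, nuint remainder)
    {
        // Simplified bulk path: one vector per iteration plus an overlapping store for the tail,
        // standing in for the aligned, unrolled, optionally non-temporal loops in the real code.
        nuint last = remainder - (uint)Vector128<float>.Count;

        for (nuint i = 0; i < last; i += (uint)Vector128<float>.Count)
        {
            (-Vector128.LoadUnsafe(ref xRef, i)).StoreUnsafe(ref dRef, i);
        }

        (-Vector128.LoadUnsafe(ref xRef, last)).StoreUnsafe(ref dRef, last); // tail, may overlap the previous block
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    static void VectorizedSmall(ref float xRef, ref float dRef, nuint remainder)
    {
        switch (remainder) // falls through 3 -> 2 -> 1 -> 0, like the jump tables above
        {
            case 3: Unsafe.Add(ref dRef, 2) = -Unsafe.Add(ref xRef, 2); goto case 2;
            case 2: Unsafe.Add(ref dRef, 1) = -Unsafe.Add(ref xRef, 1); goto case 1;
            case 1: dRef = -xRef; goto case 0;
            case 0: break;
        }
    }

    static void SoftwareFallback(ref float xRef, ref float dRef, nuint length)
    {
        for (nuint i = 0; i < length; i++)
        {
            Unsafe.Add(ref dRef, (nint)i) = -Unsafe.Add(ref xRef, (nint)i);
        }
    }
}

A given call to DispatchSketch.Negate takes exactly one of the three paths, chosen by the acceleration check and the single length comparison, which is the branch-minimizing layout the comments in this patch describe.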
diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index cde38a70fbbef4..5e6e9ac6252e3c 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -522,10 +522,10 @@ static void Vectorized(ref float xRef, ref float dRef, nuint remainder, TUnaryOp { // Store the last block, which includes any elements that wouldn't fill a full vector AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block AsVector(ref dRefBeg) = beg; @@ -578,12 +578,11 @@ static void VectorizedSmall(ref float xRef, ref float dRef, nuint remainder, TUn case 1: { dRef = op.Invoke(xRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -597,7 +596,7 @@ static void VectorizedSmall(ref float xRef, ref float dRef, nuint remainder, TUn /// /// Specifies the operation to perform on the pair-wise elements loaded from and . /// - private static void InvokeSpanSpanIntoSpan( + private static unsafe void InvokeSpanSpanIntoSpan( ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op = default) where TBinaryOperator : struct, IBinaryOperator { @@ -614,48 +613,298 @@ private static void InvokeSpanSpanIntoSpan( ValidateInputOutputSpanNonOverlapping(x, destination); ValidateInputOutputSpanNonOverlapping(y, destination); + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); if (Vector.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
+ + VectorizedSmall(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length, TBinaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) { - // Loop handling one vector at a time. - do + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i)); + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; - i += Vector.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
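Cases 1 and 0 of the switch below store the end and beg vectors that were computed before anything was written to the destination. That ordering matters when the destination is exactly the same span as an input, which the overlap validation above appears to permit for in-place use. A standalone sketch with Vector<float> (illustrative only, not the library code):

using System;
using System.Numerics;

class PreloadSketch
{
    static void Main()
    {
        int count = Vector<float>.Count;
        float[] data = new float[count + 3];          // not a multiple of Count, so the two blocks overlap
        for (int i = 0; i < data.Length; i++) data[i] = i;

        Span<float> x = data;                         // in-place: the input and destination alias
        Span<float> d = data;

        // Load (and compute) the first and last blocks before storing anything, as the helper above does.
        Vector<float> beg = new Vector<float>(x) * 2f;
        Vector<float> end = new Vector<float>(x.Slice(x.Length - count)) * 2f;

        beg.CopyTo(d);                                // front block
        end.CopyTo(d.Slice(d.Length - count));        // overlapping tail block, still correct because
                                                      // 'end' was read before 'beg' overwrote the overlap

        Console.WriteLine(string.Join(", ", data));   // every element comes out exactly doubled
    }
}

If end were loaded after beg had been stored, the overlapped elements would be doubled twice in the in-place case; preloading both blocks up front is what makes the overlapping stores in the jump table below harmless.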
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex))); + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; } - return; + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } } } - while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i)); + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + 
Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef); + goto case 0; + } - i++; + case 0: + { + break; + } + } } } @@ -682,7 +931,7 @@ private static void InvokeSpanScalarIntoSpan( /// /// Specifies the operation to perform on the transformed value from with . /// - private static void InvokeSpanScalarIntoSpan( + private static unsafe void InvokeSpanScalarIntoSpan( ReadOnlySpan x, float y, Span destination, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) where TTransformOperator : struct, IUnaryOperator where TBinaryOperator : struct, IBinaryOperator @@ -694,200 +943,294 @@ private static void InvokeSpanScalarIntoSpan.Count; - if (oneVectorFromEnd >= 0) + if (remainder >= (uint)(Vector.Count)) { - // Loop handling one vector at a time. - Vector yVec = new(y); - do - { - AsVector(ref dRef, i) = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, i)), - yVec); + Vectorized(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + VectorizedSmall(ref xRef, y, ref dRef, remainder); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, lastVectorIndex)), yVec)); - } + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref dRef, remainder); - return; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, (nint)(i))), + y); } } - // Loop handling one element at a time. - while (i < x.Length) + static void Vectorized(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) { - Unsafe.Add(ref dRef, i) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, i)), - y); + ref float dRefBeg = ref dRef; - i++; - } - } + // Preload the beginning and end so that overlapping accesses don't negatively impact the data - /// - /// Performs an element-wise operation on , , and , - /// and writes the results to . - /// - /// - /// Specifies the operation to perform on the pair-wise elements loaded from , , - /// and . 
- /// - private static void InvokeSpanSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) - where TTernaryOperator : struct, ITernaryOperator - { - if (x.Length != y.Length || x.Length != z.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + Vector yVec = new Vector(y); - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + Vector beg = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef)), + yVec); + Vector end = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))), + yVec); - ValidateInputOutputSpanNonOverlapping(x, destination); - ValidateInputOutputSpanNonOverlapping(y, destination); - ValidateInputOutputSpanNonOverlapping(z, destination); + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; - if (Vector.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) { - // Loop handling one vector at a time. - do + case 8: { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i), - AsVector(ref zRef, i)); - - i += Vector.Count; + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) + case 7: { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - AsVector(ref zRef, lastVectorIndex))); + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; } - return; - } - } + case 6: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } - // Loop handling one element at a time. - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); + case 5: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } - i++; - } - } + case 4: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } - /// - /// Performs an element-wise operation on , , and , - /// and writes the results to . - /// - /// - /// Specifies the operation to perform on the pair-wise elements loaded from and - /// with . - /// - private static void InvokeSpanSpanScalarIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) - where TTernaryOperator : struct, ITernaryOperator - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + case 3: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + case 2: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } - ValidateInputOutputSpanNonOverlapping(x, destination); - ValidateInputOutputSpanNonOverlapping(y, destination); + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } - if (Vector.IsHardwareAccelerated) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + switch (remainder) { - Vector 
zVec = new(z); + case 7: + { + Unsafe.Add(ref dRef, 6) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 6)), + y); + goto case 6; + } - // Loop handling one vector at a time. - do + case 6: { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i), - zVec); + Unsafe.Add(ref dRef, 5) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 5)), + y); + goto case 5; + } - i += Vector.Count; + case 5: + { + Unsafe.Add(ref dRef, 4) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 4)), + y); + goto case 4; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 4: { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - zVec)); + Unsafe.Add(ref dRef, 3) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 3)), + y); + goto case 3; } - return; - } - } + case 3: + { + Unsafe.Add(ref dRef, 2) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } - // Loop handling one element at a time. - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); + case 2: + { + Unsafe.Add(ref dRef, 1) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = binaryOp.Invoke(xTransformOp.Invoke(xRef), y); + goto case 0; + } - i++; + case 0: + { + break; + } + } } } @@ -896,14 +1239,14 @@ private static void InvokeSpanSpanScalarIntoSpan( /// and writes the results to . /// /// - /// Specifies the operation to perform on the pair-wise element loaded from , with , - /// and the element loaded from . + /// Specifies the operation to perform on the pair-wise elements loaded from , , + /// and . /// - private static void InvokeSpanScalarSpanIntoSpan( - ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + private static unsafe void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) where TTernaryOperator : struct, ITernaryOperator { - if (x.Length != z.Length) + if (x.Length != y.Length || x.Length != z.Length) { ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); } @@ -914,56 +1257,1009 @@ private static void InvokeSpanScalarSpanIntoSpan( } ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); ValidateInputOutputSpanNonOverlapping(z, destination); ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); ref float zRef = ref MemoryMarshal.GetReference(z); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); if (Vector.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else { - Vector yVec = new(y); + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
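
For orientation, the following is a minimal standalone sketch of the three-tier dispatch shape the rewritten helpers use: a full Vectorized path when at least one vector's worth of data is present, a VectorizedSmall path for sub-vector lengths, and a scalar SoftwareFallback when SIMD is not accelerated. The operation and every name in the sketch (ScaleBy, VectorPath, ScalarPath) are hypothetical and not part of TensorPrimitives; the real Vectorized path additionally aligns, unrolls, and uses the jump-table tail shown in this diff.

using System;
using System.Numerics;

internal static class DispatchSketch
{
    public static void ScaleBy(ReadOnlySpan<float> x, float y, Span<float> destination)
    {
        if (x.Length > destination.Length)
        {
            throw new ArgumentException("Destination is too short.", nameof(destination));
        }

        if (Vector.IsHardwareAccelerated)
        {
            if (x.Length >= Vector<float>.Count)
            {
                VectorPath(x, y, destination);   // at least one full vector: take the SIMD path
            }
            else
            {
                VectorizedSmallPath(x, y, destination);   // fewer elements than one vector
            }

            return;
        }

        // Software fallback when SIMD acceleration is unavailable: one element at a time.
        VectorizedSmallPath(x, y, destination);
    }

    private static void VectorPath(ReadOnlySpan<float> x, float y, Span<float> destination)
    {
        Vector<float> yVec = new(y);
        int i = 0;

        // One vector at a time; the real code additionally aligns the destination,
        // unrolls by eight vectors, and finishes with a jump table.
        for (; i <= x.Length - Vector<float>.Count; i += Vector<float>.Count)
        {
            (new Vector<float>(x.Slice(i)) * yVec).CopyTo(destination.Slice(i));
        }

        for (; i < x.Length; i++)
        {
            destination[i] = x[i] * y;
        }
    }

    private static void VectorizedSmallPath(ReadOnlySpan<float> x, float y, Span<float> destination)
    {
        for (int i = 0; i < x.Length; i++)
        {
            destination[i] = x[i] * y;
        }
    }
}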
- // Loop handling one vector at a time. - do - { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - yVec, - AsVector(ref zRef, i)); + VectorizedSmall(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + return; + } - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - op.Invoke(AsVector(ref xRef, lastVectorIndex), - yVec, - AsVector(ref zRef, lastVectorIndex))); - } + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); - return; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + Unsafe.Add(ref zRef, (nint)(i))); } } - // Loop handling one element at a time. - while (i < x.Length) + static void Vectorized(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); + ref float dRefBeg = ref dRef; - i++; + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
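
The alignment step described in the comment above reduces to a single computation. The sketch below, with a hypothetical helper name (ElementsUntilAligned), illustrates the idea and is not code from this change.

using System.Numerics;
using System.Runtime.CompilerServices;

internal static class AlignmentSketch
{
    // Hypothetical helper: how many float elements to advance so that dPtr reaches the
    // next Vector<float>-aligned address.
    public static unsafe nuint ElementsUntilAligned(float* dPtr)
    {
        nuint vectorBytes = (uint)Unsafe.SizeOf<Vector<float>>();

        // Bytes remaining until the next vectorBytes boundary, converted to elements.
        // An already-aligned pointer yields Vector<float>.Count rather than 0, which matches
        // the code above: the first (possibly unaligned) block is always covered separately
        // by the preloaded `beg` vector.
        return (vectorBytes - ((nuint)dPtr % vectorBytes)) / sizeof(float);
    }
}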
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
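
The jump table introduced by the comment above is driven by a small piece of arithmetic: round the leftover element count up to a whole number of vectors, then use that vector count as the switch case. A standalone sketch with hypothetical names follows; the `~(count - 1)` mask is equivalent to the `(nuint)(-Vector<float>.Count)` used in the diff because Count is a power of two.

using System;
using System.Numerics;

internal static class RemainderSketch
{
    public static void Describe(nuint remainder)
    {
        nuint count = (uint)Vector<float>.Count;

        nuint endIndex = remainder;                               // `end` is stored at endIndex - count
        nuint rounded = (remainder + (count - 1)) & ~(count - 1); // round up to a multiple of Count
        nuint caseIndex = rounded / count;                        // lands in [1, 8] once the main loop has exited

        Console.WriteLine($"remainder={remainder} endIndex={endIndex} rounded={rounded} case={caseIndex}");
    }
}

For example, with Vector<float>.Count == 8 and 27 elements left over, rounded is 32 and the switch enters at case 4: cases 4, 3, and 2 store three full vectors, case 1 stores the precomputed end vector over the final eight (partially overlapping) elements, and case 0 stores beg at the start of the destination. The overlap is harmless because the overlapped elements receive identical results.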
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + Unsafe.Add(ref zRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + Unsafe.Add(ref zRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) 
= op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, zRef); + break; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and + /// with . + /// + private static unsafe void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + z); + } + } + + static void Vectorized(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector zVec = new Vector(z); + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + zVec); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + zVec); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. 
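
Each Vectorized helper in this change begins by preloading `beg` and `end` vectors. The safe-code sketch below (a hypothetical AddScalar helper, not part of this change) shows why that preload removes the need for a scalar tail loop or masking: any elements the bulk loop re-processes are later overwritten with identical values.

using System;
using System.Numerics;

internal static class PreloadSketch
{
    public static void AddScalar(ReadOnlySpan<float> x, float y, Span<float> destination)
    {
        int count = Vector<float>.Count;
        if (x.Length < count || destination.Length < x.Length)
        {
            throw new ArgumentException("Sketch requires at least one full vector and a large enough destination.");
        }

        Vector<float> yVec = new(y);

        // Compute the results for the first and last full vectors up front.
        Vector<float> beg = new Vector<float>(x) + yVec;
        Vector<float> end = new Vector<float>(x.Slice(x.Length - count)) + yVec;

        // Bulk loop over whole vectors; deliberately ignores the ragged head and tail.
        for (int i = count; i <= x.Length - count; i += count)
        {
            (new Vector<float>(x.Slice(i)) + yVec).CopyTo(destination.Slice(i));
        }

        // Storing the preloaded blocks covers the head and the (possibly overlapping) tail.
        end.CopyTo(destination.Slice(x.Length - count));
        beg.CopyTo(destination);
    }
}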
+ + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
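
The beg/end preloads and the jump table that follows go through an AsVector accessor whose definition sits elsewhere in this file and is not visible in this hunk. The sketch below is an assumption about its likely shape, reinterpreting a ref float at an element offset as a ref Vector<float>; it is not a copy of the actual helper.

using System.Numerics;
using System.Runtime.CompilerServices;

internal static class AsVectorSketch
{
    // Plausible shape of the AsVector helpers used above (an assumption, not the real code):
    // treat the float at `start` as the first element of an unaligned Vector<float>.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static ref Vector<float> AsVector(ref float start) =>
        ref Unsafe.As<float, Vector<float>>(ref start);

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static ref Vector<float> AsVector(ref float start, nuint elementOffset) =>
        ref Unsafe.As<float, Vector<float>>(ref Unsafe.Add(ref start, (nint)elementOffset));
}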
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + z); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + z); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + z); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + z); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . 
+ /// + /// + /// Specifies the operation to perform on the pair-wise element loaded from , with , + /// and the element loaded from . + /// + private static unsafe void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + y, + Unsafe.Add(ref zRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector yVec = new Vector(y); + + Vector beg = op.Invoke(AsVector(ref xRef), + yVec, + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
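
The `xRef = ref *xPtr;` assignments above rely on ref reassignment (C# 7.3+): once the pinned bulk loop has advanced the raw pointers, the by-ref parameters are re-pointed at the first unprocessed element so the remaining ref-based tail code needs no further pinning. A minimal, purely illustrative sketch with hypothetical names:

internal static class RefReassignSketch
{
    // Hypothetical stand-alone illustration of the pattern; `processed` is assumed to be
    // no larger than the number of elements reachable from xRef and dRef.
    public static unsafe void Advance(ref float xRef, ref float dRef, nuint processed)
    {
        fixed (float* px = &xRef)
        fixed (float* pd = &dRef)
        {
            float* xPtr = px + processed;   // pretend the bulk loop moved the pointers
            float* dPtr = pd + processed;

            xRef = ref *xPtr;               // ref reassignment: xRef now refers to the
            dRef = ref *dPtr;               // first unprocessed element
        }

        // From here on, Unsafe.Add(ref xRef, i) indexes relative to the advanced position.
    }
}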
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + y, + Unsafe.Add(ref zRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + y, + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + y, + Unsafe.Add(ref zRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + y, + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + Debug.Assert(remainder == 0); + break; + } + } } }
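
The VectorizedSmall helpers in this change handle inputs shorter than one vector with a switch whose cases fall through via goto, giving one length-based jump followed by straight-line stores. A compact standalone sketch of that shape for a hypothetical add-scalar operation:

using System.Runtime.CompilerServices;

internal static class SmallSwitchSketch
{
    // Hypothetical handler for 0..7 remaining floats, mirroring the VectorizedSmall shape:
    // a single jump on the remainder, then linear execution through the lower cases.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void AddScalarSmall(ref float xRef, float y, ref float dRef, nuint remainder)
    {
        switch (remainder)
        {
            case 7: Unsafe.Add(ref dRef, 6) = Unsafe.Add(ref xRef, 6) + y; goto case 6;
            case 6: Unsafe.Add(ref dRef, 5) = Unsafe.Add(ref xRef, 5) + y; goto case 5;
            case 5: Unsafe.Add(ref dRef, 4) = Unsafe.Add(ref xRef, 4) + y; goto case 4;
            case 4: Unsafe.Add(ref dRef, 3) = Unsafe.Add(ref xRef, 3) + y; goto case 3;
            case 3: Unsafe.Add(ref dRef, 2) = Unsafe.Add(ref xRef, 2) + y; goto case 2;
            case 2: Unsafe.Add(ref dRef, 1) = Unsafe.Add(ref xRef, 1) + y; goto case 1;
            case 1: dRef = xRef + y; goto case 0;
            case 0: break;
        }
    }
}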