From bae2b67a37f415ab9b418dde7bfe524f5c85fa5e Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Wed, 11 Oct 2023 10:38:51 -0700 Subject: [PATCH 1/9] Update InvokeSpanSpanIntoSpan for TensorPrimitives to use the better SIMD algorithm --- .../System.Numerics.Tensors.sln | 24 +- .../Tensors/TensorPrimitives.netcore.cs | 1034 +++++++++++++++-- .../Tensors/TensorPrimitives.netstandard.cs | 299 ++++- 3 files changed, 1254 insertions(+), 103 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln index 5a1bf792736d94..dc07cb0b831459 100644 --- a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln +++ b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln @@ -1,4 +1,8 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.8.34205.153 +MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{9F20CEA1-2216-4432-BBBD-F01E05D17F23}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "..\Microsoft.Bcl.Numerics\ref\Microsoft.Bcl.Numerics.csproj", "{D311ABE4-10A9-4BB1-89CE-6358C55501A8}" @@ -33,11 +37,11 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{AB415F5A-75E EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{083161E5-6049-4D84-9739-9D7797D7117D}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "tools\gen", "{841A2FA4-A95F-4612-A8B9-AD2EF769BC71}" +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{841A2FA4-A95F-4612-A8B9-AD2EF769BC71}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "tools\src", "{DF0561A1-3AB8-4B51-AFB4-392EE1DD6247}" +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DF0561A1-3AB8-4B51-AFB4-392EE1DD6247}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "tools\ref", "{7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB}" +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B}" EndProject @@ -105,23 +109,27 @@ Global EndGlobalSection GlobalSection(NestedProjects) = preSolution {9F20CEA1-2216-4432-BBBD-F01E05D17F23} = {DE94CA7D-BB10-4865-85A6-6B694631247F} - {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} {D311ABE4-10A9-4BB1-89CE-6358C55501A8} = {6BC42E6D-848C-4533-B715-F116E7DB3610} - {21CB448A-3882-4337-B416-D1A3E0BCFFC5} = {6BC42E6D-848C-4533-B715-F116E7DB3610} {1578185F-C4FA-4866-936B-E62AAEDD03B7} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} + {21CB448A-3882-4337-B416-D1A3E0BCFFC5} = {6BC42E6D-848C-4533-B715-F116E7DB3610} {848DD000-3D22-4A25-A9D9-05AFF857A116} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} + {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} {4588351F-4233-4957-B84C-7F8E22B8888A} = {083161E5-6049-4D84-9739-9D7797D7117D} {DB954E01-898A-4FE2-A3AA-180D041AB08F} = {083161E5-6049-4D84-9739-9D7797D7117D} {04FC0651-B9D0-448A-A28B-11B1D4A897F4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} {683A7D28-CC55-4375-848D-E659075ECEE4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} - {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} = 
{F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} {1CBEAEA8-2CA1-4B07-9930-35A785205852} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} {BA7828B1-7953-47A0-AE5A-E22B501C4BD0} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} - {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} {57E57290-3A6A-43F8-8764-D4DC8151F89C} = {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} + {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} + {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {10A5F2C3-5230-4916-9D4D-BBDB94851037} EndGlobalSection + GlobalSection(SharedMSBuildProjectFiles) = preSolution + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{683a7d28-cc55-4375-848d-e659075ecee4}*SharedItemsImports = 5 + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{ba7828b1-7953-47a0-ae5a-e22b501c4bd0}*SharedItemsImports = 5 + EndGlobalSection EndGlobal diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 8166797f348969..9ae6ccee7f1ad9 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -2138,110 +2138,1002 @@ private static void InvokeSpanSpanIntoSpan( ValidateInputOutputSpanNonOverlapping(x, destination); ValidateInputOutputSpanNonOverlapping(y, destination); + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); #if NET8_0_OR_GREATER if (Vector512.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector512.Count)) { - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); - } + Vectorized512(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
- return; + Vectorized512Small(ref xRef, ref yRef, ref dRef, remainder); } + + return; } #endif if (Vector256.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector256.Count)) { - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Vectorized256(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + Vectorized256Small(ref xRef, ref yRef, ref dRef, remainder); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); - } + return; + } - return; + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, ref dRef, remainder); } + + return; } - if (Vector128.IsHardwareAccelerated) + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + for (nuint i = 0; i < length; i++) { - // Loop handling one vector at a time. - do + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. 
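+ // Pinning also lets the unrolled loop below work with raw pointers, so we can
+ // align dPtr and, for large inputs, use aligned non-temporal stores.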
+ + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) { - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; - i += Vector128.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
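+ // Illustrative case (assuming the unrolled loop above did not run): 13 remaining
+ // elements with 4 floats per Vector128 round up to 16, so the switch enters at
+ // case 4; cases 4..2 cover elements 0..11, case 1 stores the preloaded 'end' vector
+ // over elements 9..12, and default stores the preloaded 'beg' vector over 0..3.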
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; } - return; - } - } + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i)); + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } - i++; - } + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = 
TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
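+ // Each iteration handles 8 vectors of 8 floats (64 elements); the non-temporal
+ // stores keep this bulk output from evicting useful data from the cache.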
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = 
TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
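+ // StoreAlignedNonTemporal requires dPtr to be aligned to the vector size (64 bytes
+ // for Vector512), which the misalignment adjustment above guarantees when canAlign is true.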
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
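+ // remainder is rounded up to the next multiple of the vector count, so the switch
+ // below enters at the number of (possibly partial) vectors left, at most 8, and
+ // falls through the smaller cases via goto.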
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + 
} + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } +#endif } /// diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index cde38a70fbbef4..b51fc5289bcf15 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -597,7 +597,7 @@ static void VectorizedSmall(ref float xRef, ref float dRef, nuint remainder, TUn /// /// Specifies the operation to perform on the pair-wise elements loaded from and . /// - private static void InvokeSpanSpanIntoSpan( + private static unsafe void InvokeSpanSpanIntoSpan( ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op = default) where TBinaryOperator : struct, IBinaryOperator { @@ -614,48 +614,299 @@ private static void InvokeSpanSpanIntoSpan( ValidateInputOutputSpanNonOverlapping(x, destination); ValidateInputOutputSpanNonOverlapping(y, destination); + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); if (Vector.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (remainder >= (uint)(Vector.Count)) { - // Loop handling one vector at a time. - do + Vectorized(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
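+ // Fewer than Vector.Count elements remain here, so VectorizedSmall finishes
+ // them one by one via a fall-through switch on the exact count.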
+ + VectorizedSmall(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length, TBinaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i)); + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; - i += Vector.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
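+ // For example, assuming a 32-byte Vector and a dPtr that is 8 bytes past a 32-byte
+ // boundary: misalignment = (32 - 8) / 4 = 6, so the pointers advance 6 floats and
+ // dPtr lands on the next 32-byte boundary.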
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
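+ // Same jump-table shape as the .NET Core paths, but using Vector reads and writes
+ // through the AsVector helper instead of LoadUnsafe/StoreUnsafe.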
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex))); + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; } - return; + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } } } - while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i)); + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6)); + goto case 6; + } - i++; + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + 
Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } } From 0f5fb2a3feee007e167369a4a3f41d07e5661c4e Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Wed, 11 Oct 2023 10:56:40 -0700 Subject: [PATCH 2/9] Update InvokeSpanScalarIntoSpan for TensorPrimitives to use the better SIMD algorithm --- .../Tensors/TensorPrimitives.netcore.cs | 1028 +++++++++++++++-- .../Tensors/TensorPrimitives.netstandard.cs | 299 ++++- 2 files changed, 1227 insertions(+), 100 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 9ae6ccee7f1ad9..b72c89b770debb 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -3171,115 +3171,995 @@ private static void InvokeSpanScalarIntoSpan.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector512.Count)) { - Vector512 yVec = Vector512.Create(y); + Vectorized512(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i)), - yVec).StoreUnsafe(ref dRef, (uint)i); + Vectorized512Small(ref xRef, y, ref dRef, remainder); + } - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + return; + } +#endif - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex)), - yVec)).StoreUnsafe(ref dRef, lastVectorIndex); - } + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + Vectorized256Small(ref xRef, y, ref dRef, remainder); } + + return; } -#endif - if (Vector256.IsHardwareAccelerated) + if (Vector128.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector128.Count)) { - Vector256 yVec = Vector256.Create(y); + Vectorized128(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - // Loop handling one vector at a time. 
- do - { - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i)), - yVec).StoreUnsafe(ref dRef, (uint)i); + Vectorized128Small(ref xRef, y, ref dRef, remainder); + } - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + return; + } - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex)), - yVec)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // This is the software fallback when no acceleration is available + // It requires no branches to hit - return; + SoftwareFallback(ref xRef, y, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, i)), + y); } } - if (Vector128.IsHardwareAccelerated) + static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remainder) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + if (remainder > (uint)(Vector128.Count * 8)) { - Vector128 yVec = Vector128.Create(y); + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - // Loop handling one vector at a time. - do + fixed (float* px = &xRef) + fixed (float* pd = &dRef) { - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i)), - yVec).StoreUnsafe(ref dRef, (uint)i); + float* xPtr = px; + float* dPtr = pd; - i += Vector128.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
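+ // Unlike the two-span overload, only xPtr and dPtr need adjusting here; the scalar
+ // y was already broadcast into yVec above and needs no alignment.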
+ + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
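// A minimal sketch of the jump-table arithmetic that follows, assuming Vector128<float>.Count == 4.
// The leftover element count is rounded up to the next multiple of the vector width; the switch is
// then entered at the case equal to the number of vector-sized blocks still needed, and the final,
// possibly partial block is covered by the precomputed 'end' vector stored at 'endIndex - Count',
// deliberately overlapping the previous store instead of using a mask.
using System;

nuint count = 4;                                            // Vector128<float>.Count
nuint remainder = 27;                                       // elements left over after the unrolled loop

nuint endIndex = remainder;
remainder = (remainder + (count - 1)) & ~(count - 1);       // same as '& (nuint)(-Count)' for a power-of-two Count

Console.WriteLine(remainder);                               // 28: rounded up to a multiple of 4
Console.WriteLine(remainder / count);                       // 7: the switch is entered at 'case 7'
Console.WriteLine(endIndex - count);                        // 23: offset of the overlapping 'end' store (elements 23..26)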
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex)), - yVec)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; } - return; - } - } + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, i)), - y); + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } - i++; - } + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } + + static void Vectorized256(ref float xRef, 
float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
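// A compact sketch of the store-selection policy used here: once the destination is known to be
// vector-aligned and the input exceeds a size threshold, streaming (non-temporal) stores are used so
// the written lines are not kept resident in the cache; smaller or unaligned work falls back to
// ordinary stores. 'Fill' and the 256 KB threshold are illustrative stand-ins, not the library's
// actual NonTemporalByteThreshold value.
using System.Runtime.Intrinsics;

static class NonTemporalSketch
{
    public static unsafe void Fill(float* dPtr, nuint length, float value)
    {
        nuint thresholdBytes = 256 * 1024;                  // assumed threshold, for illustration only
        Vector128<float> vec = Vector128.Create(value);
        nuint i = 0;

        if ((((nuint)dPtr % (uint)sizeof(Vector128<float>)) == 0) && ((length * sizeof(float)) > thresholdBytes))
        {
            for (; i + (nuint)Vector128<float>.Count <= length; i += (nuint)Vector128<float>.Count)
            {
                vec.StoreAlignedNonTemporal(dPtr + i);      // streaming store, bypasses the cache hierarchy
            }
        }
        else
        {
            for (; i + (nuint)Vector128<float>.Count <= length; i += (nuint)Vector128<float>.Count)
            {
                vec.Store(dPtr + i);                        // regular, cacheable store
            }
        }

        for (; i < length; i++)                             // scalar tail
        {
            dPtr[i] = value;
        }
    }
}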
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + 
dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 yVec = Vector512.Create(y); + + Vector512 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)), + yVec); + Vector512 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))), + yVec); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + Vector256.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = 
TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } +#endif } /// diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index b51fc5289bcf15..b8f444b58b96c8 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -933,7 +933,7 @@ private static void InvokeSpanScalarIntoSpan( /// /// Specifies the operation to perform on the transformed value from with . /// - private static void InvokeSpanScalarIntoSpan( + private static unsafe void InvokeSpanScalarIntoSpan( ReadOnlySpan x, float y, Span destination, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) where TTransformOperator : struct, IUnaryOperator where TBinaryOperator : struct, IBinaryOperator @@ -945,48 +945,295 @@ private static void InvokeSpanScalarIntoSpan.Count; - if (oneVectorFromEnd >= 0) + if (remainder >= (uint)(Vector.Count)) { - // Loop handling one vector at a time. - Vector yVec = new(y); - do + Vectorized(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
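// A stripped-down illustration of the pattern described above: a single switch on the element count
// where each case handles one element and uses 'goto case' to fall through to the next smaller case,
// so sub-vector inputs take one length check, one jump, and then straight-line code. Plain
// multiplication stands in for the operator abstraction used by the real helpers.
using System.Diagnostics;
using System.Runtime.CompilerServices;

static class SmallJumpTableSketch
{
    public static void ScaleSmall(ref float xRef, float y, ref float dRef, nuint remainder)
    {
        switch (remainder)
        {
            case 3:
                Unsafe.Add(ref dRef, 2) = Unsafe.Add(ref xRef, 2) * y;
                goto case 2;

            case 2:
                Unsafe.Add(ref dRef, 1) = Unsafe.Add(ref xRef, 1) * y;
                goto case 1;

            case 1:
                dRef = xRef * y;
                break;

            default:
                Debug.Assert(remainder == 0);
                break;
        }
    }
}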
+ + VectorizedSmall(ref xRef, y, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, (nint)(i))), + y); + } + } + + static void Vectorized(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector yVec = new Vector(y); + + Vector beg = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef)), + yVec); + Vector end = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))), + yVec); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) { - AsVector(ref dRef, i) = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, i)), - yVec); + float* xPtr = px; + float* dPtr = pd; - i += Vector.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
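// The jump table that follows addresses whole Vector<float> blocks by reinterpreting the float
// storage rather than using Vector128/256/512 loads. A minimal sketch of that reinterpretation,
// assuming a helper equivalent in spirit to this file's AsVector; the exact helper shape and names
// are assumptions made only for illustration.
using System.Numerics;
using System.Runtime.CompilerServices;

static class AsVectorSketch
{
    // Reinterprets the float at 'start + offset' as the first element of a Vector<float> block.
    public static ref Vector<float> AsVector(ref float start, nuint offset) =>
        ref Unsafe.As<float, Vector<float>>(ref Unsafe.Add(ref start, (nint)offset));

    // Example: scale one vector-width block in place through the reinterpreted reference.
    public static void ScaleBlock(ref float data, nuint offset, float y) =>
        AsVector(ref data, offset) *= new Vector<float>(y);
}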
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, lastVectorIndex)), yVec)); + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; } - return; + case 7: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } } } - // Loop handling one element at a time. 
- while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) { - Unsafe.Add(ref dRef, i) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, i)), - y); + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 6)), + y); + goto case 6; + } - i++; + case 6: + { + Unsafe.Add(ref dRef, 5) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 5)), + y); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 4)), + y); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 3)), + y); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = binaryOp.Invoke(xTransformOp.Invoke(xRef), y); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } } From 5059d0189b0837f0e2dab21a3beb3e890d61e1ec Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Wed, 11 Oct 2023 11:24:44 -0700 Subject: [PATCH 3/9] Update InvokeSpanSpanSpanIntoSpan for TensorPrimitives to use the better SIMD algorithm --- .../Tensors/TensorPrimitives.netcore.cs | 1146 +++++++++++++++-- .../Tensors/TensorPrimitives.netstandard.cs | 327 ++++- 2 files changed, 1369 insertions(+), 104 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index b72c89b770debb..6ce03d81095916 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -4188,118 +4188,1112 @@ private static void InvokeSpanSpanSpanIntoSpan( ValidateInputOutputSpanNonOverlapping(y, destination); ValidateInputOutputSpanNonOverlapping(z, destination); + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float zRef = ref MemoryMarshal.GetReference(z); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); #if NET8_0_OR_GREATER if (Vector512.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector512.Count)) { - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. 
- if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - Vector512.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); - } + Vectorized512(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + Vectorized512Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); } + + return; } #endif if (Vector256.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector256.Count)) { - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Vectorized256(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + Vectorized256Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - Vector256.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); - } + return; + } - return; + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); } + + return; } - if (Vector128.IsHardwareAccelerated) + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + for (nuint i = 0; i < length; i++) { - // Loop handling one vector at a time. 
- do + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; - i += Vector128.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - Vector128.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; } - return; - } - } + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } - while (i < x.Length) + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements 
preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
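+            // As a concrete example: if 19 elements remain and Vector256<float>.Count is 8, endIndex stays 19
+            // while remainder is rounded up to 24, so the switch enters at case 3. Cases 3 and 2 store full
+            // vectors at offsets 0 and 8 from the current position, case 1 stores the precomputed 'end' vector
+            // at offset 19 - 8 = 11 (harmlessly re-writing elements 11..15 with identical values), and the
+            // default case re-stores the precomputed 'beg' vector at the original start of the destination.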
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + 
Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 
9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); - i++; + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } +#endif } /// diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index b8f444b58b96c8..717857a2961aa6 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -1245,7 +1245,7 @@ static void VectorizedSmall(ref float xRef, float y, ref float dRef, nuint remai /// Specifies the operation to perform on the pair-wise elements loaded from , , /// and . 
/// - private static void InvokeSpanSpanSpanIntoSpan( + private static unsafe void InvokeSpanSpanSpanIntoSpan( ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) where TTernaryOperator : struct, ITernaryOperator { @@ -1267,49 +1267,320 @@ private static void InvokeSpanSpanSpanIntoSpan( ref float yRef = ref MemoryMarshal.GetReference(y); ref float zRef = ref MemoryMarshal.GetReference(z); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); if (Vector.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (remainder >= (uint)(Vector.Count)) { - // Loop handling one vector at a time. - do + Vectorized(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + Unsafe.Add(ref zRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i), - AsVector(ref zRef, i)); + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; - i += Vector.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
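+                        // Worked example (assuming a 32-byte Vector<float>, i.e. 8 elements): if dPtr % 32 == 24,
+                        // misalignment is (32 - 24) / 4 = 2, so all four pointers advance by 2 elements and dPtr
+                        // becomes 32-byte aligned; those 2 leading elements are already covered by 'beg' above.
+                        // If dPtr happens to be aligned already, the formula yields a full vector's worth (8),
+                        // which is equally safe because 'beg' has produced that first vector as well.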
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - AsVector(ref zRef, lastVectorIndex))); + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; } - return; + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } } } - // Loop handling one element at a time. 
- while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + Unsafe.Add(ref zRef, 6)); + goto case 6; + } - i++; + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + Unsafe.Add(ref zRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, zRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } } From 520a3e1fd26208283c86ac89093340b72649bd70 Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Wed, 11 Oct 2023 11:56:05 -0700 Subject: [PATCH 4/9] Update InvokeSpanSpanScalarIntoSpan for TensorPrimitives to use the better SIMD algorithm --- .../Tensors/TensorPrimitives.netcore.cs | 1142 +++++++++++++++-- .../Tensors/TensorPrimitives.netstandard.cs | 324 ++++- 2 files changed, 1357 insertions(+), 109 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 6ce03d81095916..7567f8a9090cc4 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -5321,123 +5321,1105 @@ private static void InvokeSpanSpanScalarIntoSpan( ValidateInputOutputSpanNonOverlapping(x, destination); ValidateInputOutputSpanNonOverlapping(y, destination); + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - Vector512 zVec = Vector512.Create(z); + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
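+                // Each of the *Small helpers is a switch over the remaining element count whose cases fall
+                // through: for instance, a remainder of 3 computes element 2 in case 3, then falls into case 2
+                // and case 1, so the tail is handled by straight-line code instead of a counted loop, and
+                // remainders of a half or quarter vector reuse the narrower vector widths where available.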
+ + Vectorized512Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + 
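+            // Loading the first and last full vectors from the sources up front means every source read needed
+            // for the edges happens before any store to the destination, which can matter if the destination
+            // aliases x or y (e.g. an in-place call), and it lets the aligned main loop and the remainder jump
+            // table overwrite the edges freely: with 21 elements and a Count of 8, for example, 'end' is built
+            // from source elements 13..20 here and only stored at the very end, so the duplicated writes to
+            // indices 13..15 are harmless.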
Vector256 zVec = Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = 
TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                           Vector128.LoadUnsafe(ref yRef),
+                                                           Vector128.Create(z));
+                    beg.StoreUnsafe(ref dRef);
+
+                    break;
+                }
+
+                case 3:
+                {
+                    Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2),
+                                                                      Unsafe.Add(ref yRef, 2),
+                                                                      z);
+                    goto case 2;
+                }
+
+                case 2:
+                {
+                    Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1),
+                                                                      Unsafe.Add(ref yRef, 1),
+                                                                      z);
+                    goto case 1;
+                }
+
+                case 1:
+                {
+                    dRef = TTernaryOperator.Invoke(xRef, yRef, z);
+                    break;
+                }
+
+                default:
+                {
+                    Debug.Assert(remainder == 0);
+                    break;
+                }
+            }
+        }
+
+#if NET8_0_OR_GREATER
+        static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder)
+        {
+            ref float dRefBeg = ref dRef;
+
+            // Preload the beginning and end so that overlapping accesses don't negatively impact the data
+
+            Vector512 zVec = Vector512.Create(z);
+
+            Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef),
+                                                    Vector512.LoadUnsafe(ref yRef),
+                                                    zVec);
+            Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)),
+                                                    Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)),
+                                                    zVec);
+
+            if (remainder > (uint)(Vector512.Count * 8))
+            {
+                // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
+                // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.
+
+                fixed (float* px = &xRef)
+                fixed (float* py = &yRef)
+                fixed (float* pd = &dRef)
+                {
+                    float* xPtr = px;
+                    float* yPtr = py;
+                    float* dPtr = pd;
+
+                    // We need to ensure the underlying data can be aligned and only align
+                    // it if it can. It is possible we have an unaligned ref, in which case we
+                    // can never achieve the required SIMD alignment.
+
+                    bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+
+                    if (canAlign)
+                    {
+                        // Compute by how many elements we're misaligned and adjust the pointers accordingly
+                        //
+                        // Noting that we are only actually aligning dPtr. This is because unaligned stores
+                        // are more expensive than unaligned loads and aligning both is significantly more
+                        // complex.
+
+                        nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float);
+
+                        xPtr += misalignment;
+                        yPtr += misalignment;
+                        dPtr += misalignment;
+
+                        Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0);
+
+                        remainder -= misalignment;
+                    }
+
+                    Vector512 vector1;
+                    Vector512 vector2;
+                    Vector512 vector3;
+                    Vector512 vector4;
+
+                    if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float))))
+                    {
+                        // This loop stores the data non-temporally, which benefits us when there
+                        // is a large amount of data involved as it avoids polluting the cache.
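The idea behind that gate is worth spelling out: once the destination is much larger than the cache, writing it through the cache only evicts data the caller still cares about, so aligned streaming stores are used instead. NonTemporalByteThreshold is referenced here but its definition is not part of this hunk; the snippet below assumes a value purely for illustration.

    // Minimal sketch of the non-temporal gate (threshold value assumed, not from the patch).
    nuint assumedNonTemporalByteThreshold = 256 * 1024;   // e.g. roughly "bigger than L2"
    bool useNonTemporalStores = canAlign && (remainder > (assumedNonTemporalByteThreshold / sizeof(float)));
    // When true, the loop uses StoreAlignedNonTemporal so the large destination streams past
    // the cache; otherwise the plain Store path further below is taken.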
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); - // Loop handling one vector at a time. 
- do + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } - i += Vector512.Count; + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) + case 6: { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - zVec)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; } - return; - } - } -#endif + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - Vector256 zVec = Vector256.Create(z); + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } - // Loop handling one vector at a time. - do + case 3: { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } - i += Vector256.Count; + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) + case 1: { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - zVec)).StoreUnsafe(ref dRef, lastVectorIndex); + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto default; } - return; + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } } } - if (Vector128.IsHardwareAccelerated) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + switch (remainder) { - Vector128 zVec = Vector128.Create(z); + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); - // Loop handling one vector at a time. - do + Vector256 zVec = Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + Debug.Assert(Vector256.IsHardwareAccelerated); - i += Vector128.Count; + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) + case 7: + case 6: + case 5: { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - zVec)).StoreUnsafe(ref dRef, lastVectorIndex); + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; } - return; - } - } + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); + beg.StoreUnsafe(ref dRef); - i++; + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } +#endif } /// diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index 717857a2961aa6..cb5617f76f7c0b 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -1592,7 +1592,7 @@ static void VectorizedSmall(ref float xRef, ref float yRef, ref float zRef, ref /// Specifies the operation to perform on the pair-wise elements loaded from and /// with . /// - private static void InvokeSpanSpanScalarIntoSpan( + private static unsafe void InvokeSpanSpanScalarIntoSpan( ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) where TTernaryOperator : struct, ITernaryOperator { @@ -1612,51 +1612,317 @@ private static void InvokeSpanSpanScalarIntoSpan( ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); if (Vector.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, z, ref dRef, remainder); + } + else { - Vector zVec = new(z); + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
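The VectorizedSmall helper called just below handles those sub-vector lengths with a switch whose cases fall through via goto, so the whole tail costs one computed jump followed by straight-line stores. A minimal sketch of that pattern (helper name and stored value invented for illustration):

    // Hypothetical illustration of the fallthrough shape; not code from the patch.
    static void StoreSmallTail(ref float dRef, float value, nuint remainder)
    {
        switch (remainder)
        {
            case 3: Unsafe.Add(ref dRef, 2) = value; goto case 2;
            case 2: Unsafe.Add(ref dRef, 1) = value; goto case 1;
            case 1: dRef = value; break;
            default: Debug.Assert(remainder == 0); break;   // nothing left to write
        }
    }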
- // Loop handling one vector at a time. - do + VectorizedSmall(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + z); + } + } + + static void Vectorized(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector zVec = new Vector(z); + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + zVec); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + zVec); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i), - zVec); + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; - i += Vector.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
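As a concrete illustration of the adjustment computed next (addresses assumed, not from the patch): with AVX2, Vector<float> is 32 bytes, so a destination sitting 24 bytes past a 32-byte boundary is two floats short of the next aligned address.

    // Assumed example values; the real code derives these from dPtr at run time.
    nuint vectorBytes  = 32;                                          // sizeof(Vector<float>) with AVX2
    nuint addressMod   = 24;                                          // dPtr % 32 in this example
    nuint misalignment = (vectorBytes - addressMod) / sizeof(float);  // 2 elements
    // xPtr, yPtr and dPtr all advance by those 2 elements; the skipped head is covered by the
    // preloaded 'beg' vector that the jump table's default case stores back at the original dRef.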
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
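The AsVector helper used throughout this netstandard path is not shown in this hunk; it presumably just reinterprets a float reference plus an element offset as a Vector<float>, along these lines (an assumption about its shape, not the file's actual implementation):

    // Assumed shape of the existing AsVector helper, shown only to make the hunk easier to read.
    static ref Vector<float> AsVector(ref float start, nuint elementOffset = 0) =>
        ref Unsafe.As<float, Vector<float>>(ref Unsafe.Add(ref start, (nint)elementOffset));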
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - zVec)); + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; } - return; + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } } } - // Loop handling one element at a time. 
- while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + z); + goto case 6; + } - i++; + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + z); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + z); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + z); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, z); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } } From b1dec16002b52f4c1d56b7250813b380c84d993b Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Wed, 11 Oct 2023 12:11:15 -0700 Subject: [PATCH 5/9] Update InvokeSpanScalarSpanIntoSpan for TensorPrimitives to use the better SIMD algorithm --- .../Tensors/TensorPrimitives.netcore.cs | 1142 +++++++++++++++-- .../Tensors/TensorPrimitives.netstandard.cs | 324 ++++- 2 files changed, 1357 insertions(+), 109 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 7567f8a9090cc4..dfbbc4e5d7e299 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -6447,123 +6447,1105 @@ private static void InvokeSpanScalarSpanIntoSpan( ValidateInputOutputSpanNonOverlapping(x, destination); ValidateInputOutputSpanNonOverlapping(z, destination); + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); ref float zRef = ref MemoryMarshal.GetReference(z); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - Vector512 yVec = Vector512.Create(y); + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
+ + Vectorized512Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
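Note that for lengths of at most eight vectors the pinned loop above never runs: remainder still holds the full length, the refs still point at the start of the spans, and the jump table below does every store. A small assumed example: with Vector128<float>.Count == 4 and 23 elements, 23 rounds up to 24 and the switch enters at case 6, writing offsets [0, 4) through [16, 20), then the preloaded 'end' vector at offset 19 and finally 'beg' at offset 0, all from the unadjusted references.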
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } + + static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + 
Vector256 yVec = Vector256.Create(y);
+
+            Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
+                                                    yVec,
+                                                    Vector256.LoadUnsafe(ref zRef));
+            Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)),
+                                                    yVec,
+                                                    Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count)));
+
+            if (remainder > (uint)(Vector256.Count * 8))
+            {
+                // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
+                // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.
+
+                fixed (float* px = &xRef)
+                fixed (float* pz = &zRef)
+                fixed (float* pd = &dRef)
+                {
+                    float* xPtr = px;
+                    float* zPtr = pz;
+                    float* dPtr = pd;
+
+                    // We need to ensure the underlying data can be aligned and only align
+                    // it if it can. It is possible we have an unaligned ref, in which case we
+                    // can never achieve the required SIMD alignment.
+
+                    bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+
+                    if (canAlign)
+                    {
+                        // Compute by how many elements we're misaligned and adjust the pointers accordingly
+                        //
+                        // Noting that we are only actually aligning dPtr. This is because unaligned stores
+                        // are more expensive than unaligned loads and aligning both is significantly more
+                        // complex.
+
+                        nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float);
+
+                        xPtr += misalignment;
+                        zPtr += misalignment;
+                        dPtr += misalignment;
+
+                        Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0);
+
+                        remainder -= misalignment;
+                    }
+
+                    Vector256 vector1;
+                    Vector256 vector2;
+                    Vector256 vector3;
+                    Vector256 vector4;
+
+                    if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float))))
+                    {
+                        // This loop stores the data non-temporally, which benefits us when there
+                        // is a large amount of data involved as it avoids polluting the cache.
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = 
TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                           Vector128.Create(y),
+                                                           Vector128.LoadUnsafe(ref zRef));
+                    beg.StoreUnsafe(ref dRef);
+
+                    break;
+                }
+
+                case 3:
+                {
+                    Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2),
+                                                                      y,
+                                                                      Unsafe.Add(ref zRef, 2));
+                    goto case 2;
+                }
+
+                case 2:
+                {
+                    Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1),
+                                                                      y,
+                                                                      Unsafe.Add(ref zRef, 1));
+                    goto case 1;
+                }
+
+                case 1:
+                {
+                    dRef = TTernaryOperator.Invoke(xRef, y, zRef);
+                    break;
+                }
+
+                default:
+                {
+                    Debug.Assert(remainder == 0);
+                    break;
+                }
+            }
+        }
+
+#if NET8_0_OR_GREATER
+        static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder)
+        {
+            ref float dRefBeg = ref dRef;
+
+            // Preload the beginning and end so that overlapping accesses don't negatively impact the data
+
+            Vector512 yVec = Vector512.Create(y);
+
+            Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef),
+                                                    yVec,
+                                                    Vector512.LoadUnsafe(ref zRef));
+            Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)),
+                                                    yVec,
+                                                    Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count)));
+
+            if (remainder > (uint)(Vector512.Count * 8))
+            {
+                // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
+                // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.
+
+                fixed (float* px = &xRef)
+                fixed (float* pz = &zRef)
+                fixed (float* pd = &dRef)
+                {
+                    float* xPtr = px;
+                    float* zPtr = pz;
+                    float* dPtr = pd;
+
+                    // We need to ensure the underlying data can be aligned and only align
+                    // it if it can. It is possible we have an unaligned ref, in which case we
+                    // can never achieve the required SIMD alignment.
+
+                    bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0;
+
+                    if (canAlign)
+                    {
+                        // Compute by how many elements we're misaligned and adjust the pointers accordingly
+                        //
+                        // Noting that we are only actually aligning dPtr. This is because unaligned stores
+                        // are more expensive than unaligned loads and aligning both is significantly more
+                        // complex.
+
+                        nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float);
+
+                        xPtr += misalignment;
+                        zPtr += misalignment;
+                        dPtr += misalignment;
+
+                        Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0);
+
+                        remainder -= misalignment;
+                    }
+
+                    Vector512 vector1;
+                    Vector512 vector2;
+                    Vector512 vector3;
+                    Vector512 vector4;
+
+                    if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float))))
+                    {
+                        // This loop stores the data non-temporally, which benefits us when there
+                        // is a large amount of data involved as it avoids polluting the cache.
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); - // Loop handling one vector at a time. - do + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
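One detail of the rounding these jump tables share: masking with (nuint)(-Count) only works because the element counts are powers of two, so the mask clears the low bits in a single AND instead of a divide-and-multiply. A small check with assumed numbers:

    // Assumes Vector512<float>.Count == 16 and 37 remaining elements.
    nuint length  = 37;
    nuint viaMask = (length + 15) & ~(nuint)15;    // 48
    nuint viaDiv  = ((length + 15) / 16) * 16;     // also 48, but with a division
    // (nuint)(-Vector512<float>.Count) in the patch is the same mask written via two's
    // complement: -16 has every bit set except the low four.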
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } - i += Vector512.Count; + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 6: { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector512.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; } - return; - } - } -#endif + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - Vector256 yVec = Vector256.Create(y); + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } - // Loop handling one vector at a time. - do + case 3: { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } - i += Vector256.Count; + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) + case 1: { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector256.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto default; } - return; + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } } } - if (Vector128.IsHardwareAccelerated) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + switch (remainder) { - Vector128 yVec = Vector128.Create(y); + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); - // Loop handling one vector at a time. - do + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + yVec, + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.Create(y), + Vector256.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); - i += Vector128.Count; + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) + case 4: { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector128.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; } - return; - } - } + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } - i++; + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } +#endif } /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index cb5617f76f7c0b..ff39976a709b40 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -1934,7 +1934,7 @@ static void VectorizedSmall(ref float xRef, ref float yRef, float z, ref float d /// Specifies the operation to perform on the pair-wise element loaded from , with , /// and the element loaded from . /// - private static void InvokeSpanScalarSpanIntoSpan( + private static unsafe void InvokeSpanScalarSpanIntoSpan( ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) where TTernaryOperator : struct, ITernaryOperator { @@ -1954,51 +1954,317 @@ private static void InvokeSpanScalarSpanIntoSpan( ref float xRef = ref MemoryMarshal.GetReference(x); ref float zRef = ref MemoryMarshal.GetReference(z); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); if (Vector.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
+ + VectorizedSmall(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) { - Vector yVec = new(y); + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + y, + Unsafe.Add(ref zRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector yVec = new Vector(y); - // Loop handling one vector at a time. - do + Vector beg = op.Invoke(AsVector(ref xRef), + yVec, + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - yVec, - AsVector(ref zRef, i)); + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; - i += Vector.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
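// Editor's illustration (assumed address, not part of the patch): on AVX2-capable
// hardware Vector is 32 bytes (8 floats). If dPtr starts at an address ending in 0x18,
// the computation below yields (32 - 24) / 4 == 2, so advancing xPtr, zPtr, and dPtr by
// 2 elements (8 bytes) leaves dPtr on a 32-byte boundary; those 2 skipped elements are
// still covered by the preloaded 'beg' vector that the jump table stores last.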
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
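// Editor's note (illustrative): Vector.Count is a power of two, so the mask
// (nuint)(-Vector.Count) clears the low bits and the expression below rounds the
// remainder up to the next multiple of Count without a division, e.g. with Count == 8:
//   (13 + 7) & ~7 == 16    and    (16 + 7) & ~7 == 16.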
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - op.Invoke(AsVector(ref xRef, lastVectorIndex), - yVec, - AsVector(ref zRef, lastVectorIndex))); + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; } - return; + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } } } - // Loop handling one element at a time. 
- while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + y, + Unsafe.Add(ref zRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + y, + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + y, + Unsafe.Add(ref zRef, 4)); + goto case 4; + } - i++; + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + y, + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, y, zRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } } From 58a9047032ebf467dfd46107cb32374a8a4640d1 Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Wed, 11 Oct 2023 14:20:51 -0700 Subject: [PATCH 6/9] Improve codegen slightly by using case 0, rather than default --- .../Tensors/TensorPrimitives.netcore.cs | 162 ++++++++---------- .../Tensors/TensorPrimitives.netstandard.cs | 51 +++--- 2 files changed, 95 insertions(+), 118 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index dfbbc4e5d7e299..609c397dbc276d 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -1505,10 +1505,10 @@ static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -1537,12 +1537,11 @@ static void Vectorized128Small(ref float xRef, ref float dRef, nuint remainder) case 1: { dRef = TUnaryOperator.Invoke(xRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -1747,10 +1746,10 @@ static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -1804,12 +1803,11 @@ static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder) case 1: { dRef = TUnaryOperator.Invoke(xRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -2015,10 +2013,10 @@ static void Vectorized512(ref float xRef, ref float dRef, nuint remainder) { // Store the last block, which 
includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -2101,12 +2099,11 @@ static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder) case 1: { dRef = TUnaryOperator.Invoke(xRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -2450,10 +2447,10 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -2484,12 +2481,11 @@ static void Vectorized128Small(ref float xRef, ref float yRef, ref float dRef, n case 1: { dRef = TBinaryOperator.Invoke(xRef, yRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -2725,10 +2721,10 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -2787,12 +2783,11 @@ static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, n case 1: { dRef = TBinaryOperator.Invoke(xRef, yRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -3029,10 +3024,10 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -3123,12 +3118,11 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, n case 1: { dRef = TBinaryOperator.Invoke(xRef, yRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -3478,10 +3472,10 @@ static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remaind { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -3512,12 +3506,11 @@ static void Vectorized128Small(ref float xRef, float y, ref float dRef, nuint re case 1: { dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -3749,10 +3742,10 @@ static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remaind { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); - goto default; + goto case 0; } - 
default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -3813,12 +3806,11 @@ static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint re case 1: { dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -4051,10 +4043,10 @@ static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remaind { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -4149,12 +4141,11 @@ static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint re case 1: { dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -4533,10 +4524,10 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref fl { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -4569,12 +4560,11 @@ static void Vectorized128Small(ref float xRef, ref float yRef, ref float zRef, r case 1: { dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -4841,10 +4831,10 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref fl { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -4908,12 +4898,11 @@ static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, r case 1: { dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -5181,10 +5170,10 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref fl { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -5283,12 +5272,11 @@ static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, r case 1: { dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -5661,10 +5649,10 @@ static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRe { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first 
aligned block beg.StoreUnsafe(ref dRefBeg); @@ -5697,12 +5685,11 @@ static void Vectorized128Small(ref float xRef, ref float yRef, float z, ref floa case 1: { dRef = TTernaryOperator.Invoke(xRef, yRef, z); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -5965,10 +5952,10 @@ static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRe { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -6034,12 +6021,11 @@ static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref floa case 1: { dRef = TTernaryOperator.Invoke(xRef, yRef, z); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -6303,10 +6289,10 @@ static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRe { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -6409,12 +6395,11 @@ static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref floa case 1: { dRef = TTernaryOperator.Invoke(xRef, yRef, z); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -6787,10 +6772,10 @@ static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRe { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -6823,12 +6808,11 @@ static void Vectorized128Small(ref float xRef, float y, ref float zRef, ref floa case 1: { dRef = TTernaryOperator.Invoke(xRef, y, zRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -7091,10 +7075,10 @@ static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRe { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -7160,12 +7144,11 @@ static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref floa case 1: { dRef = TTernaryOperator.Invoke(xRef, y, zRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -7429,10 +7412,10 @@ static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRe { // Store the last block, which includes any elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); @@ -7535,12 +7518,11 @@ static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref 
floa case 1: { dRef = TTernaryOperator.Invoke(xRef, y, zRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index ff39976a709b40..9dab69e47d4c92 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -522,10 +522,10 @@ static void Vectorized(ref float xRef, ref float dRef, nuint remainder, TUnaryOp { // Store the last block, which includes any elements that wouldn't fill a full vector AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block AsVector(ref dRefBeg) = beg; @@ -578,12 +578,11 @@ static void VectorizedSmall(ref float xRef, ref float dRef, nuint remainder, TUn case 1: { dRef = op.Invoke(xRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -836,10 +835,10 @@ static void Vectorized(ref float xRef, ref float yRef, ref float dRef, nuint rem { // Store the last block, which includes any elements that wouldn't fill a full vector AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block AsVector(ref dRefBeg) = beg; @@ -898,12 +897,11 @@ static void VectorizedSmall(ref float xRef, ref float yRef, ref float dRef, nuin case 1: { dRef = op.Invoke(xRef, yRef); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -1163,10 +1161,10 @@ static void Vectorized(ref float xRef, float y, ref float dRef, nuint remainder, { // Store the last block, which includes any elements that wouldn't fill a full vector AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block AsVector(ref dRefBeg) = beg; @@ -1225,12 +1223,11 @@ static void VectorizedSmall(ref float xRef, float y, ref float dRef, nuint remai case 1: { dRef = binaryOp.Invoke(xTransformOp.Invoke(xRef), y); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -1504,10 +1501,10 @@ static void Vectorized(ref float xRef, ref float yRef, ref float zRef, ref float { // Store the last block, which includes any elements that wouldn't fill a full vector AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block AsVector(ref dRefBeg) = beg; @@ -1575,9 +1572,8 @@ static void VectorizedSmall(ref float xRef, ref float yRef, ref float zRef, ref break; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -1846,10 +1842,10 @@ static void Vectorized(ref float xRef, ref float yRef, float z, ref float dRef, { // Store the last block, which includes any elements that wouldn't fill a full vector AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes 
any elements preceding the first aligned block AsVector(ref dRefBeg) = beg; @@ -1914,12 +1910,11 @@ static void VectorizedSmall(ref float xRef, ref float yRef, float z, ref float d case 1: { dRef = op.Invoke(xRef, yRef, z); - break; + goto case 0; } - default: + case 0: { - Debug.Assert(remainder == 0); break; } } @@ -2188,10 +2183,10 @@ static void Vectorized(ref float xRef, float y, ref float zRef, ref float dRef, { // Store the last block, which includes any elements that wouldn't fill a full vector AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; - goto default; + goto case 0; } - default: + case 0: { // Store the first block, which includes any elements preceding the first aligned block AsVector(ref dRefBeg) = beg; @@ -2256,10 +2251,10 @@ static void VectorizedSmall(ref float xRef, float y, ref float zRef, ref float d case 1: { dRef = op.Invoke(xRef, y, zRef); - break; + goto case 0; } - default: + case 0: { Debug.Assert(remainder == 0); break; From e3e6ae2a5e8cc7040ef2c30a5627d4bdce586643 Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Thu, 12 Oct 2023 10:42:33 -0700 Subject: [PATCH 7/9] Adjust the canAlign check to be latter, to reduce branch count for data under the threshold --- .../Tensors/TensorPrimitives.netcore.cs | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 609c397dbc276d..16006e216e4bf6 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -1355,7 +1355,7 @@ static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) Vector128 vector3; Vector128 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -1596,7 +1596,7 @@ static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) Vector256 vector3; Vector256 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -1863,7 +1863,7 @@ static void Vectorized512(ref float xRef, ref float dRef, nuint remainder) Vector512 vector3; Vector512 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -2271,7 +2271,7 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint Vector128 vector3; Vector128 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. 
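// Editor's note on the change above (not part of the patch): both operands of the '&&'
// are side-effect free, so swapping them only changes evaluation order. With the size
// check first, inputs at or below NonTemporalByteThreshold short-circuit and never
// evaluate the canAlign test, which is the branch-count reduction the commit message
// describes; inputs above the threshold still honor canAlign before taking the
// non-temporal store path.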
@@ -2545,7 +2545,7 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint Vector256 vector3; Vector256 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -2848,7 +2848,7 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint Vector512 vector3; Vector512 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -3299,7 +3299,7 @@ static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remaind Vector128 vector3; Vector128 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -3569,7 +3569,7 @@ static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remaind Vector256 vector3; Vector256 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -3870,7 +3870,7 @@ static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remaind Vector512 vector3; Vector512 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -4322,7 +4322,7 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref fl Vector128 vector3; Vector128 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -4629,7 +4629,7 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref fl Vector256 vector3; Vector256 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -4968,7 +4968,7 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref fl Vector512 vector3; Vector512 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. 
@@ -5450,7 +5450,7 @@ static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRe Vector128 vector3; Vector128 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -5753,7 +5753,7 @@ static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRe Vector256 vector3; Vector256 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -6090,7 +6090,7 @@ static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRe Vector512 vector3; Vector512 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -6573,7 +6573,7 @@ static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRe Vector128 vector3; Vector128 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -6876,7 +6876,7 @@ static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRe Vector256 vector3; Vector256 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. @@ -7213,7 +7213,7 @@ static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRe Vector512 vector3; Vector512 vector4; - if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. 
From 42a2014c7591894e1340594c4b9bca22cf26314a Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Thu, 12 Oct 2023 10:43:03 -0700 Subject: [PATCH 8/9] Add a comment explaining the NonTemporalByteThreshold --- .../Numerics/Tensors/TensorPrimitives.netcore.cs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 16006e216e4bf6..2e07bda4efad26 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -13,6 +13,20 @@ namespace System.Numerics.Tensors { public static partial class TensorPrimitives { + /// Defines the threshold, in bytes, at which non-temporal stores will be used. + /// + /// A non-temporal store is one that allows the CPU to bypass the cache when writing to memory. + /// + /// This can be beneficial when working with large amounts of memory where the writes would otherwise + /// cause large amounts of repeated updates and evictions. The hardware optimization manuals recommend + /// the threshold to be roughly half the size of the last level of on-die cache -- that is, if you have approximately + /// 4MB of L3 cache per core, you'd want this to be approx. 1-2MB, depending on if hyperthreading was enabled. + /// + /// However, actually computing the amount of L3 cache per core can be tricky or error prone. Native memcpy + /// algorithms use a constant threshold that is typically around 256KB and we match that here for simplicity. This + /// threshold accounts for most processors in the last 10-15 years that had approx. 1MB L3 per core and support + /// hyperthreading, giving a per core last level cache of approx. 512KB. + /// private const nuint NonTemporalByteThreshold = 256 * 1024; /// From f03dc1df90cb17f07e6797903ab68b1ab415e257 Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Mon, 16 Oct 2023 11:31:52 -0700 Subject: [PATCH 9/9] Make sure xTransformOp.CanVectorize is checked on .NET Standard --- .../src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index 9dab69e47d4c92..5e6e9ac6252e3c 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -953,7 +953,7 @@ private static unsafe void InvokeSpanScalarIntoSpan= (uint)(Vector.Count)) {