diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index b31a427139..0294804b29 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -20,6 +20,7 @@ internal static class AvxIntrinsics { private static readonly Vector256 _absMask256 = Avx.StaticCast(Avx.SetAllVector256(0x7FFFFFFF)); + // The count of bytes in Vector256, corresponding to _cbAlign in AlignedArray private const int Vector256Alignment = 32; [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] @@ -415,12 +416,12 @@ public static unsafe void AddScalarU(float scalar, Span dst) { fixed (float* pdst = dst) { - float* pDstEnd = pdst + dst.Length; - float* pDstCurrent = pdst; - Vector256 scalarVector256 = Avx.SetAllVector256(scalar); - while (pDstCurrent + 8 <= pDstEnd) + int count = Math.DivRem(dst.Length, 8, out int remainder); + float* pDstCurrent = pdst; + + for (int i = 0; i < count; i++) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = Avx.Add(dstVector, scalarVector256); @@ -431,7 +432,7 @@ public static unsafe void AddScalarU(float scalar, Span dst) Vector128 scalarVector128 = Sse.SetAllVector128(scalar); - if (pDstCurrent + 4 <= pDstEnd) + if (remainder >= 4) { Vector128 dstVector = Sse.LoadVector128(pDstCurrent); dstVector = Sse.Add(dstVector, scalarVector128); @@ -440,13 +441,9 @@ public static unsafe void AddScalarU(float scalar, Span dst) pDstCurrent += 4; } - while (pDstCurrent < pDstEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - dstVector = Sse.AddScalar(dstVector, scalarVector128); - Sse.StoreScalar(pDstCurrent, dstVector); - - pDstCurrent++; + pDstCurrent[i] += scalar; } } } @@ -455,12 +452,12 @@ public static unsafe void ScaleU(float scale, Span dst) { fixed (float* pdst = dst) { - float* pDstCurrent = pdst; - float* pEnd = pdst + dst.Length; - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - while (pDstCurrent + 8 <= pEnd) + int count = Math.DivRem(dst.Length, 8, out int remainder); + float* pDstCurrent = pdst; + + for (int i = 0; i < count; i++) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -472,7 +469,7 @@ public static unsafe void ScaleU(float scale, Span dst) Vector128 scaleVector128 = Sse.SetAllVector128(scale); - if (pDstCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -482,14 +479,9 @@ public static unsafe void ScaleU(float scale, Span dst) pDstCurrent += 4; } - while (pDstCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - dstVector = Sse.MultiplyScalar(scaleVector128, dstVector); - Sse.StoreScalar(pDstCurrent, dstVector); - - pDstCurrent++; + pDstCurrent[i] *= scale; } } } @@ -499,13 +491,13 @@ public static unsafe void ScaleSrcU(float scale, Span src, Span ds fixed (float* psrc = src) fixed (float* pdst = dst) { - float* pDstEnd = pdst + dst.Length; + Vector256 scaleVector256 = Avx.SetAllVector256(scale); + + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - - while (pDstCurrent + 8 <= pDstEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Multiply(srcVector, scaleVector256); @@ -517,7 +509,7 @@ public static unsafe void ScaleSrcU(float scale, Span src, Span ds Vector128 scaleVector128 = Sse.SetAllVector128(scale); - if (pDstCurrent + 4 <= pDstEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Multiply(srcVector, scaleVector128); @@ -527,14 +519,9 @@ public static unsafe void ScaleSrcU(float scale, Span src, Span ds pDstCurrent += 4; } - while (pDstCurrent < pDstEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - srcVector = Sse.MultiplyScalar(srcVector, scaleVector128); - Sse.StoreScalar(pDstCurrent, srcVector); - - pSrcCurrent++; - pDstCurrent++; + pDstCurrent[i] = pSrcCurrent[i] * scale; } } } @@ -544,13 +531,13 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) { fixed (float* pdst = dst) { - float* pDstEnd = pdst + dst.Length; - float* pDstCurrent = pdst; - Vector256 a256 = Avx.SetAllVector256(a); Vector256 b256 = Avx.SetAllVector256(b); - while (pDstCurrent + 8 <= pDstEnd) + int count = Math.DivRem(dst.Length, 8, out int remainder); + float* pDstCurrent = pdst; + + for (int i = 0; i < count; i++) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = Avx.Add(dstVector, b256); @@ -563,7 +550,7 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) Vector128 a128 = Sse.SetAllVector128(a); Vector128 b128 = Sse.SetAllVector128(b); - if (pDstCurrent + 4 <= pDstEnd) + if (remainder >= 4) { Vector128 dstVector = Sse.LoadVector128(pDstCurrent); dstVector = Sse.Add(dstVector, b128); @@ -573,14 +560,9 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) pDstCurrent += 4; } - while (pDstCurrent < pDstEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - dstVector = Sse.AddScalar(dstVector, b128); - dstVector = Sse.MultiplyScalar(dstVector, a128); - Sse.StoreScalar(pDstCurrent, dstVector); - - pDstCurrent++; + pDstCurrent[i] = a * (pDstCurrent[i] + b); } } } @@ -590,13 +572,13 @@ public static unsafe void AddScaleU(float scale, Span src, Span ds fixed (float* psrc = src) fixed (float* pdst = dst) { + Vector256 scaleVector256 = Avx.SetAllVector256(scale); + + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - float* pEnd = pdst + dst.Length; - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - - while (pDstCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -611,7 +593,7 @@ public static unsafe void AddScaleU(float scale, Span src, Span ds Vector128 scaleVector128 = Sse.SetAllVector128(scale); - if (pDstCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -624,17 +606,9 @@ public static unsafe void AddScaleU(float scale, Span src, Span ds pDstCurrent += 4; } - while (pDstCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - srcVector = Sse.MultiplyScalar(srcVector, scaleVector128); - dstVector = Sse.AddScalar(dstVector, srcVector); - Sse.StoreScalar(pDstCurrent, dstVector); - - pSrcCurrent++; - pDstCurrent++; + pDstCurrent[i] += scale * pSrcCurrent[i]; } } } @@ -645,14 +619,14 @@ public static unsafe void AddScaleCopyU(float scale, Span src, Span scaleVector256 = Avx.SetAllVector256(scale); + + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pResCurrent = pres; - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - - while (pResCurrent + 8 <= pResEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -667,7 +641,7 @@ public static unsafe void AddScaleCopyU(float scale, Span src, Span scaleVector128 = Sse.SetAllVector128(scale); - if (pResCurrent + 4 <= pResEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -680,17 +654,9 @@ public static unsafe void AddScaleCopyU(float scale, Span src, Span srcVector = Sse.LoadScalarVector128(pSrcCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - srcVector = Sse.MultiplyScalar(srcVector, scaleVector128); - dstVector = Sse.AddScalar(dstVector, srcVector); - Sse.StoreScalar(pResCurrent, dstVector); - - pSrcCurrent++; - pDstCurrent++; - pResCurrent++; + pResCurrent[i] = pDstCurrent[i] + scale * pSrcCurrent[i]; } } } @@ -701,14 +667,14 @@ public static unsafe void AddScaleSU(float scale, Span src, Span idx fixed (int* pidx = idx) fixed (float* pdst = dst) { + Vector256 scaleVector256 = Avx.SetAllVector256(scale); + + int count = Math.DivRem(idx.Length, 8, out int remainder); float* pSrcCurrent = psrc; int* pIdxCurrent = pidx; float* pDstCurrent = pdst; - int* pEnd = pidx + idx.Length; - - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - while (pIdxCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Load8(pDstCurrent, pIdxCurrent); @@ -723,7 +689,7 @@ public static unsafe void AddScaleSU(float scale, Span src, Span idx Vector128 scaleVector128 = Sse.SetAllVector128(scale); - if (pIdxCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = SseIntrinsics.Load4(pDstCurrent, pIdxCurrent); @@ -736,12 +702,10 @@ public static unsafe void AddScaleSU(float scale, Span src, Span idx pSrcCurrent += 4; } - while (pIdxCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - pDstCurrent[*pIdxCurrent] += scale * (*pSrcCurrent); - - pIdxCurrent++; - pSrcCurrent++; + int index = pIdxCurrent[i]; + pDstCurrent[index] += scale * pSrcCurrent[i]; } } } @@ -751,11 +715,11 @@ public static unsafe void AddU(Span src, Span dst) fixed (float* psrc = src) fixed (float* pdst = dst) { + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - float* pEnd = psrc + src.Length; - while (pSrcCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -767,7 +731,7 @@ public static unsafe void AddU(Span src, Span dst) pDstCurrent += 8; } - if (pSrcCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -779,16 +743,9 @@ public static unsafe void AddU(Span src, Span dst) pDstCurrent += 4; } - while (pSrcCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - Vector128 result = Sse.AddScalar(srcVector, dstVector); - Sse.StoreScalar(pDstCurrent, result); - - pSrcCurrent++; - pDstCurrent++; + pDstCurrent[i] += pSrcCurrent[i]; } } } @@ -799,12 +756,12 @@ public static unsafe void AddSU(Span src, Span idx, Span dst) fixed (int* pidx = idx) fixed (float* pdst = dst) { + int count = Math.DivRem(idx.Length, 8, out int remainder); float* pSrcCurrent = psrc; int* pIdxCurrent = pidx; float* pDstCurrent = pdst; - int* pEnd = pidx + idx.Length; - while (pIdxCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 dstVector = Load8(pDstCurrent, pIdxCurrent); Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); @@ -816,7 +773,7 @@ public static unsafe void AddSU(Span src, Span idx, Span dst) pSrcCurrent += 8; } - if (pIdxCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 dstVector = SseIntrinsics.Load4(pDstCurrent, pIdxCurrent); Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); @@ -828,12 +785,10 @@ public static unsafe void AddSU(Span src, Span idx, Span dst) pSrcCurrent += 4; } - while (pIdxCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - pDstCurrent[*pIdxCurrent] += *pSrcCurrent; - - pIdxCurrent++; - pSrcCurrent++; + int index = pIdxCurrent[i]; + pDstCurrent[index] += pSrcCurrent[i]; } } } @@ -844,12 +799,12 @@ public static unsafe void MulElementWiseU(Span src1, Span src2, Sp fixed (float* psrc2 = src2) fixed (float* pdst = dst) { + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrc1Current = psrc1; float* pSrc2Current = psrc2; float* pDstCurrent = pdst; - float* pEnd = pdst + dst.Length; - while (pDstCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 src1Vector = Avx.LoadVector256(pSrc1Current); Vector256 src2Vector = Avx.LoadVector256(pSrc2Current); @@ -861,7 +816,7 @@ public static unsafe void MulElementWiseU(Span src1, Span src2, Sp pDstCurrent += 8; } - if (pDstCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 src1Vector = Sse.LoadVector128(pSrc1Current); Vector128 src2Vector = Sse.LoadVector128(pSrc2Current); @@ -873,16 +828,9 @@ public static unsafe void MulElementWiseU(Span src1, Span src2, Sp pDstCurrent += 4; } - while (pDstCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 src1Vector = Sse.LoadScalarVector128(pSrc1Current); - Vector128 src2Vector = Sse.LoadScalarVector128(pSrc2Current); - src2Vector = Sse.MultiplyScalar(src1Vector, src2Vector); - Sse.StoreScalar(pDstCurrent, src2Vector); - - pSrc1Current++; - pSrc2Current++; - pDstCurrent++; + pDstCurrent[i] = pSrc1Current[i] * pSrc2Current[i]; } } } @@ -891,12 +839,12 @@ public static unsafe float SumU(Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { result256 = Avx.Add(result256, Avx.LoadVector256(pSrcCurrent)); pSrcCurrent += 8; @@ -907,21 +855,21 @@ public static unsafe float SumU(Span src) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { result128 = Sse.Add(result128, Sse.LoadVector128(pSrcCurrent)); pSrcCurrent += 4; } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - result128 = Sse.AddScalar(result128, Sse.LoadScalarVector128(pSrcCurrent)); - pSrcCurrent++; + result += pSrcCurrent[i]; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -929,12 +877,12 @@ public static unsafe float SumSqU(Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Add(result256, Avx.Multiply(srcVector, srcVector)); @@ -947,7 +895,7 @@ public static unsafe float SumSqU(Span src) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result128 = Sse.Add(result128, Sse.Multiply(srcVector, srcVector)); @@ -956,16 +904,14 @@ public static unsafe float SumSqU(Span src) } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - result128 = Sse.AddScalar(result128, Sse.MultiplyScalar(srcVector, srcVector)); - - pSrcCurrent++; + result += pSrcCurrent[i] * pSrcCurrent[i]; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -973,13 +919,13 @@ public static unsafe float SumSqDiffU(float mean, Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); Vector256 meanVector256 = Avx.SetAllVector256(mean); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -994,7 +940,7 @@ public static unsafe float SumSqDiffU(float mean, Span src) Vector128 result128 = Sse.SetZeroVector128(); Vector128 meanVector128 = Sse.SetAllVector128(mean); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector128); @@ -1004,17 +950,15 @@ public static unsafe float SumSqDiffU(float mean, Span src) } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - srcVector = Sse.SubtractScalar(srcVector, meanVector128); - result128 = Sse.AddScalar(result128, Sse.MultiplyScalar(srcVector, srcVector)); - - pSrcCurrent++; + float difference = pSrcCurrent[i] - mean; + result += difference * difference; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -1022,12 +966,12 @@ public static unsafe float SumAbsU(Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Add(result256, Avx.And(srcVector, _absMask256)); @@ -1040,7 +984,7 @@ public static unsafe float SumAbsU(Span src) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result128 = Sse.Add(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128)); @@ -1050,7 +994,7 @@ public static unsafe float SumAbsU(Span src) result128 = SseIntrinsics.VectorSum128(in result128); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); result128 = Sse.AddScalar(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128)); @@ -1066,13 +1010,13 @@ public static unsafe float SumAbsDiffU(float mean, Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); Vector256 meanVector256 = Avx.SetAllVector256(mean); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1087,7 +1031,7 @@ public static unsafe float SumAbsDiffU(float mean, Span src) Vector128 result128 = Sse.SetZeroVector128(); Vector128 meanVector128 = Sse.SetAllVector128(mean); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector128); @@ -1098,7 +1042,7 @@ public static unsafe float SumAbsDiffU(float mean, Span src) result128 = SseIntrinsics.VectorSum128(in result128); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); srcVector = Sse.SubtractScalar(srcVector, meanVector128); @@ -1115,12 +1059,12 @@ public static unsafe float MaxAbsU(Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Max(result256, Avx.And(srcVector, _absMask256)); @@ -1133,7 +1077,7 @@ public static unsafe float MaxAbsU(Span src) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result128 = Sse.Max(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128)); @@ -1143,7 +1087,7 @@ public static unsafe float MaxAbsU(Span src) result128 = SseIntrinsics.VectorMax128(in result128); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); result128 = Sse.MaxScalar(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128)); @@ -1159,13 +1103,13 @@ public static unsafe float MaxAbsDiffU(float mean, Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); Vector256 meanVector256 = Avx.SetAllVector256(mean); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1180,7 +1124,7 @@ public static unsafe float MaxAbsDiffU(float mean, Span src) Vector128 result128 = Sse.SetZeroVector128(); Vector128 meanVector128 = Sse.SetAllVector128(mean); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector128); @@ -1191,7 +1135,7 @@ public static unsafe float MaxAbsDiffU(float mean, Span src) result128 = SseIntrinsics.VectorMax128(in result128); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); srcVector = Sse.SubtractScalar(srcVector, meanVector128); @@ -1209,13 +1153,13 @@ public static unsafe float DotU(Span src, Span dst) fixed (float* psrc = src) fixed (float* pdst = dst) { + Vector256 result256 = Avx.SetZeroVector256(); + + int count = Math.DivRem(src.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - float* pSrcEnd = psrc + src.Length; - Vector256 result256 = Avx.SetZeroVector256(); - - while (pSrcCurrent + 8 <= pSrcEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -1231,7 +1175,7 @@ public static unsafe float DotU(Span src, Span dst) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -1243,19 +1187,14 @@ public static unsafe float DotU(Span src, Span dst) } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - result128 = Sse.AddScalar(result128, Sse.MultiplyScalar(srcVector, dstVector)); - - pSrcCurrent++; - pDstCurrent++; + result += pSrcCurrent[i] * pDstCurrent[i]; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -1265,14 +1204,14 @@ public static unsafe float DotSU(Span src, Span dst, Span idx fixed (float* pdst = dst) fixed (int* pidx = idx) { + Vector256 result256 = Avx.SetZeroVector256(); + + int count = Math.DivRem(idx.Length, 8, out int remainder); float* pSrcCurrent = psrc; - float* pDstCurrent = pdst; int* pIdxCurrent = pidx; - int* pIdxEnd = pidx + idx.Length; - - Vector256 result256 = Avx.SetZeroVector256(); + float* pDstCurrent = pdst; - while (pIdxCurrent + 8 <= pIdxEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Load8(pSrcCurrent, pIdxCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -1288,7 +1227,7 @@ public static unsafe float DotSU(Span src, Span dst, Span idx Vector128 result128 = Sse.SetZeroVector128(); - if (pIdxCurrent + 4 <= pIdxEnd) + if (remainder >= 4) { Vector128 srcVector = SseIntrinsics.Load4(pSrcCurrent, pIdxCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -1300,19 +1239,15 @@ public static unsafe float DotSU(Span src, Span dst, Span idx } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pIdxCurrent < pIdxEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = SseIntrinsics.Load1(pSrcCurrent, pIdxCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - result128 = Sse.AddScalar(result128, Sse.MultiplyScalar(srcVector, dstVector)); - - pIdxCurrent++; - pDstCurrent++; + int index = pIdxCurrent[i]; + result += pSrcCurrent[index] * pDstCurrent[i]; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -1321,13 +1256,13 @@ public static unsafe float Dist2(Span src, Span dst) fixed (float* psrc = src) fixed (float* pdst = dst) { + Vector256 sqDistanceVector256 = Avx.SetZeroVector256(); + + int count = Math.DivRem(src.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - float* pSrcEnd = psrc + src.Length; - - Vector256 sqDistanceVector256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + for (int i = 0; i < count; i++) { Vector256 distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent), Avx.LoadVector256(pDstCurrent)); @@ -1343,7 +1278,7 @@ public static unsafe float Dist2(Span src, Span dst) Vector128 sqDistanceVector128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 distanceVector = Sse.Subtract(Sse.LoadVector128(pSrcCurrent), Sse.LoadVector128(pDstCurrent)); @@ -1355,15 +1290,12 @@ public static unsafe float Dist2(Span src, Span dst) } sqDistanceVector128 = SseIntrinsics.VectorSum128(in sqDistanceVector128); - float norm = Sse.ConvertToSingle(Sse.AddScalar(sqDistanceVector128, sqDistanceVectorPadded)); - while (pSrcCurrent < pSrcEnd) + + for (int i = 0; i < remainder % 4; i++) { - float distance = (*pSrcCurrent) - (*pDstCurrent); + float distance = pSrcCurrent[i] - pDstCurrent[i]; norm += distance * distance; - - pSrcCurrent++; - pDstCurrent++; } return norm; @@ -1376,15 +1308,15 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span src, flo fixed (float* pdst1 = v) fixed (float* pdst2 = w) { - float* pSrcEnd = psrc + src.Length; + Vector256 xPrimal256 = Avx.SetAllVector256(primalUpdate); + Vector256 xThreshold256 = Avx.SetAllVector256(threshold); + + int count = Math.DivRem(src.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDst1Current = pdst1; float* pDst2Current = pdst2; - Vector256 xPrimal256 = Avx.SetAllVector256(primalUpdate); - Vector256 xThreshold256 = Avx.SetAllVector256(threshold); - - while (pSrcCurrent + 8 <= pSrcEnd) + for (int i = 0; i < count; i++) { Vector256 xSrc = Avx.LoadVector256(pSrcCurrent); @@ -1403,7 +1335,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span src, flo Vector128 xPrimal128 = Sse.SetAllVector128(primalUpdate); Vector128 xThreshold128 = Sse.SetAllVector128(threshold); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 xSrc = Sse.LoadVector128(pSrcCurrent); @@ -1419,15 +1351,11 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span src, flo pDst2Current += 4; } - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - *pDst1Current += (*pSrcCurrent) * primalUpdate; - float dst1 = *pDst1Current; - *pDst2Current = Math.Abs(dst1) > threshold ? (dst1 > 0 ? dst1 - threshold : dst1 + threshold) : 0; - - pSrcCurrent++; - pDst1Current++; - pDst2Current++; + pDst1Current[i] += primalUpdate * pSrcCurrent[i]; + float dst1 = pDst1Current[i]; + pDst2Current[i] = Math.Abs(dst1) > threshold ? (dst1 > 0 ? dst1 - threshold : dst1 + threshold) : 0; } } } @@ -1439,14 +1367,14 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span src, Sp fixed (float* pdst1 = v) fixed (float* pdst2 = w) { - int* pIdxEnd = pidx + indices.Length; - float* pSrcCurrent = psrc; - int* pIdxCurrent = pidx; - Vector256 xPrimal256 = Avx.SetAllVector256(primalUpdate); Vector256 xThreshold = Avx.SetAllVector256(threshold); - while (pIdxCurrent + 8 <= pIdxEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + int* pIdxCurrent = pidx; + + for (int i = 0; i < count; i++) { Vector256 xSrc = Avx.LoadVector256(pSrcCurrent); @@ -1464,7 +1392,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span src, Sp Vector128 xPrimal128 = Sse.SetAllVector128(primalUpdate); Vector128 xThreshold128 = Sse.SetAllVector128(threshold); - if (pIdxCurrent + 4 <= pIdxEnd) + if (remainder >= 4) { Vector128 xSrc = Sse.LoadVector128(pSrcCurrent); @@ -1479,15 +1407,12 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span src, Sp pSrcCurrent += 4; } - while (pIdxCurrent < pIdxEnd) + for (int i = 0; i < remainder % 4; i++) { - int index = *pIdxCurrent; - pdst1[index] += (*pSrcCurrent) * primalUpdate; + int index = pIdxCurrent[i]; + pdst1[index] += primalUpdate * pSrcCurrent[i]; float dst1 = pdst1[index]; pdst2[index] = Math.Abs(dst1) > threshold ? (dst1 > 0 ? dst1 - threshold : dst1 + threshold) : 0; - - pIdxCurrent++; - pSrcCurrent++; } } } diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index 0f4fb54d18..76e6ed52fc 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -412,12 +412,11 @@ public static unsafe void AddScalarU(float scalar, Span dst) { fixed (float* pdst = dst) { - float* pDstEnd = pdst + dst.Length; - float* pDstCurrent = pdst; - Vector128 scalarVector = Sse.SetAllVector128(scalar); + int count = Math.DivRem(dst.Length, 4, out int remainder); + float* pDstCurrent = pdst; - while (pDstCurrent + 4 <= pDstEnd) + for (int i = 0; i < count; i++) { Vector128 dstVector = Sse.LoadVector128(pDstCurrent); dstVector = Sse.Add(dstVector, scalarVector); @@ -426,13 +425,9 @@ public static unsafe void AddScalarU(float scalar, Span dst) pDstCurrent += 4; } - while (pDstCurrent < pDstEnd) + for (int i = 0; i < remainder; i++) { - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - dstVector = Sse.AddScalar(dstVector, scalarVector); - Sse.StoreScalar(pDstCurrent, dstVector); - - pDstCurrent++; + pDstCurrent[i] += scalar; } } }