diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfOperator.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfOperator.cs index 09492ac64eb075..8e2f64f979a852 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfOperator.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfOperator.cs @@ -9,124 +9,721 @@ namespace System.Numerics.Tensors { public static unsafe partial class TensorPrimitives { - private interface IIndexOfOperator + private interface IIndexOfMinMaxOperator { - static abstract int Invoke(ref T result, T current, int resultIndex, int currentIndex); - static abstract void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 currentIndex); - static abstract void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 currentIndex); - static abstract void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 currentIndex); + static abstract bool IsQuickReturn(T value); + static abstract Vector128 IsQuickReturn(Vector128 value); + static abstract Vector256 IsQuickReturn(Vector256 value); + static abstract Vector512 IsQuickReturn(Vector512 value); + static abstract T Aggregate(Vector128 value); + static abstract T Aggregate(Vector256 value); + static abstract T Aggregate(Vector512 value); + /// Returns true if x precedes y. + static abstract bool Compare(T x, T y); + static abstract Vector128 Compare(Vector128 x, Vector128 y); + static abstract Vector256 Compare(Vector256 x, Vector256 y); + static abstract Vector512 Compare(Vector512 x, Vector512 y); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int IndexOfFinalAggregate(Vector128 result, Vector128 resultIndex) - where TIndexOfOperator : struct, IIndexOfOperator + private static int IndexOfMinMaxCore(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator { - Vector128 tmpResult; - Vector128 tmpIndex; + if (x.IsEmpty) + { + return -1; + } + + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && x.Length >= Vector512.Count) + { + return sizeof(T) == 8 ? IndexOfMinMaxVector512Size4Plus(x) : + sizeof(T) == 4 ? IndexOfMinMaxVector512Size4Plus(x) : + sizeof(T) == 2 ? IndexOfMinMaxVector512Size2(x) : + IndexOfMinMaxVector512Size1(x); + } + + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && x.Length >= Vector256.Count) + { + return sizeof(T) == 8 ? IndexOfMinMaxVector256Size4Plus(x) : + sizeof(T) == 4 ? IndexOfMinMaxVector256Size4Plus(x) : + sizeof(T) == 2 ? IndexOfMinMaxVector256Size2(x) : + IndexOfMinMaxVector256Size1(x); + } + + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && x.Length >= Vector128.Count) + { + return sizeof(T) == 8 ? IndexOfMinMaxVector128Size4Plus(x) : + sizeof(T) == 4 ? IndexOfMinMaxVector128Size4Plus(x) : + sizeof(T) == 2 ? IndexOfMinMaxVector128Size2(x) : + IndexOfMinMaxVector128Size1(x); + } + + return IndexOfMinMaxNaive(x); + } + + private static int IndexOfMinMaxNaive(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator + { + T result = x[0]; + int resultIndex = 0; + if (TOperator.IsQuickReturn(result)) + { + return resultIndex; + } + + for (int i = 1; i < x.Length; i++) + { + T current = x[i]; + if (TOperator.IsQuickReturn(current)) + { + return i; + } + if (TOperator.Compare(current, result)) + { + result = current; + resultIndex = i; + } + } + + return resultIndex; + } + + private static int IndexOfMinMaxVector128Size4Plus(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator where TInt : IBinaryInteger + { + Debug.Assert(sizeof(T) == 4 || sizeof(T) == 8); + Debug.Assert(typeof(TInt) == typeof(uint) || typeof(TInt) == typeof(ulong)); + Debug.Assert(sizeof(TInt) == sizeof(T)); + + // Initialize result by reading first vector and quick return if possible. + Vector128 result = Vector128.Create(x); + Vector128 mask = TOperator.IsQuickReturn(result); + if (mask != Vector128.Zero) + { + return IndexOfFirstMatch(mask); + } + + // Initialize indices. + Vector128 indexIncrement = Vector128.Create(TInt.CreateChecked(Vector128.Count)); + Vector128 resultIndex = Vector128.Indices; + Vector128 currentIndex = resultIndex + indexIncrement; + ReadOnlySpan span = x.Slice(Vector128.Count); + + while (!span.IsEmpty) + { + Vector128 current; + if (span.Length >= Vector128.Count) + { + current = Vector128.Create(span); + span = span.Slice(Vector128.Count); + } + else + { + // Process a final back-shifted to cover remaining elements in x in one vector. + int start = x.Length - Vector128.Count; + current = Vector128.Create(x.Slice(start)); + currentIndex = Vector128.Create(TInt.CreateChecked(start)) + Vector128.Indices; + span = ReadOnlySpan.Empty; + } + + // Quick return if possible. + mask = TOperator.IsQuickReturn(current); + if (mask != Vector128.Zero) + { + return int.CreateChecked(currentIndex.ToScalar()) + IndexOfFirstMatch(mask); + } + + // Get mask for which lanes that should have result updated. + mask = TOperator.Compare(current, result); + + // Update result and indices. + result = ElementWiseSelect(mask, current, result); + resultIndex = ElementWiseSelect(mask.As(), currentIndex, resultIndex); + currentIndex += indexIncrement; + } + + { + // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. + T aggResult = TOperator.Aggregate(result); + Vector128 aggMask = ~Vector128.Equals(result.As(), Vector128.Create(aggResult).As()); + Vector128 aggIndex = resultIndex | aggMask; + return int.CreateChecked(HorizontalAggregate>(aggIndex)); + } + } + + private static int IndexOfMinMaxVector128Size2(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator + { + Debug.Assert(sizeof(T) == 2); + + // Initialize result by reading first vector and quick return if possible. + Vector128 result = Vector128.Create(x); + Vector128 mask = TOperator.IsQuickReturn(result); + if (mask != Vector128.Zero) + { + return IndexOfFirstMatch(mask); + } + + // Initialize indices. + Vector128 indexIncrement = Vector128.Create((uint)Vector128.Count); + Vector128 resultIndex1 = Vector128.Indices; + Vector128 resultIndex2 = resultIndex1 + indexIncrement; + Vector128 currentIndex = resultIndex2 + indexIncrement; + ReadOnlySpan span = x.Slice(Vector128.Count); + + while (!span.IsEmpty) + { + Vector128 current; + if (span.Length >= Vector128.Count) + { + current = Vector128.Create(span); + span = span.Slice(Vector128.Count); + } + else + { + // Process a final back-shifted to cover remaining elements in x in one vector. + int start = x.Length - Vector128.Count; + current = Vector128.Create(x.Slice(start)); + currentIndex = Vector128.Create((uint)start) + Vector128.Indices; + span = ReadOnlySpan.Empty; + } + + // Quick return if possible. + mask = TOperator.IsQuickReturn(current); + if (mask != Vector128.Zero) + { + return (int)currentIndex.ToScalar() + IndexOfFirstMatch(mask); + } + + // Get mask for which lanes that should have result updated, also widen it for updating the indices. + mask = TOperator.Compare(current, result); + (Vector128 mask1, Vector128 mask2) = Vector128.Widen(mask.AsInt16()); + + // Update result and indices. + result = ElementWiseSelect(mask, current, result); + resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1); + currentIndex += indexIncrement; + resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2); + currentIndex += indexIncrement; + } + + { + // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. + T aggResult = TOperator.Aggregate(result); + Vector128 aggMask = ~Vector128.Equals(result.AsInt16(), Vector128.Create(aggResult).AsInt16()); + + (Vector128 mask1, Vector128 mask2) = Vector128.Widen(aggMask); + Vector128 aggIndex = resultIndex1 | mask1.AsUInt32(); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32()); + + return (int)HorizontalAggregate>(aggIndex); + } + } + + private static int IndexOfMinMaxVector128Size1(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator + { + Debug.Assert(sizeof(T) == 1); + + // Initialize result by reading first vector and quick return if possible. + Vector128 result = Vector128.Create(x); + Vector128 mask = TOperator.IsQuickReturn(result); + if (mask != Vector128.Zero) + { + return IndexOfFirstMatch(mask); + } + + // Initialize indices. + Vector128 indexIncrement = Vector128.Create((uint)Vector128.Count); + Vector128 resultIndex1 = Vector128.Indices; + Vector128 resultIndex2 = resultIndex1 + indexIncrement; + Vector128 resultIndex3 = resultIndex2 + indexIncrement; + Vector128 resultIndex4 = resultIndex3 + indexIncrement; + Vector128 currentIndex = resultIndex4 + indexIncrement; + ReadOnlySpan span = x.Slice(Vector128.Count); + + while (!span.IsEmpty) + { + Vector128 current; + if (span.Length >= Vector128.Count) + { + current = Vector128.Create(span); + span = span.Slice(Vector128.Count); + } + else + { + // Process a final back-shifted to cover remaining elements in x in one vector. + int start = x.Length - Vector128.Count; + current = Vector128.Create(x.Slice(start)); + currentIndex = Vector128.Create((uint)start) + Vector128.Indices; + span = ReadOnlySpan.Empty; + } + + // Quick return if possible. + mask = TOperator.IsQuickReturn(current); + if (mask != Vector128.Zero) + { + return (int)currentIndex.ToScalar() + IndexOfFirstMatch(mask); + } + + // Get mask for which lanes that should have result updated, also widen it for updating the indices. + mask = TOperator.Compare(current, result); + (Vector128 lowerMask, Vector128 upperMask) = Vector128.Widen(mask.AsSByte()); + (Vector128 mask1, Vector128 mask2) = Vector128.Widen(lowerMask); + (Vector128 mask3, Vector128 mask4) = Vector128.Widen(upperMask); + + // Update result and indices. + result = ElementWiseSelect(mask, current, result); + resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1); + currentIndex += indexIncrement; + resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2); + currentIndex += indexIncrement; + resultIndex3 = ElementWiseSelect(mask3.AsUInt32(), currentIndex, resultIndex3); + currentIndex += indexIncrement; + resultIndex4 = ElementWiseSelect(mask4.AsUInt32(), currentIndex, resultIndex4); + currentIndex += indexIncrement; + } + + { + // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. + T aggResult = TOperator.Aggregate(result); + Vector128 aggMask = ~Vector128.Equals(result.AsSByte(), Vector128.Create(aggResult).AsSByte()); + + (Vector128 lowerMask, Vector128 upperMask) = Vector128.Widen(aggMask); + (Vector128 mask1, Vector128 mask2) = Vector128.Widen(lowerMask); + (Vector128 mask3, Vector128 mask4) = Vector128.Widen(upperMask); + Vector128 aggIndex = resultIndex1 | mask1.AsUInt32(); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32()); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex3 | mask3.AsUInt32()); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex4 | mask4.AsUInt32()); + + return (int)HorizontalAggregate>(aggIndex); + } + } + + private static int IndexOfMinMaxVector256Size4Plus(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator where TInt : IBinaryInteger + { + Debug.Assert(sizeof(T) == 4 || sizeof(T) == 8); + Debug.Assert(typeof(TInt) == typeof(uint) || typeof(TInt) == typeof(ulong)); + Debug.Assert(sizeof(TInt) == sizeof(T)); + + // Initialize result by reading first vector and quick return if possible. + Vector256 result = Vector256.Create(x); + Vector256 mask = TOperator.IsQuickReturn(result); + if (mask != Vector256.Zero) + { + return IndexOfFirstMatch(mask); + } + + // Initialize indices. + Vector256 indexIncrement = Vector256.Create(TInt.CreateChecked(Vector256.Count)); + Vector256 resultIndex = Vector256.Indices; + Vector256 currentIndex = resultIndex + indexIncrement; + ReadOnlySpan span = x.Slice(Vector256.Count); - if (sizeof(T) == 8) + while (!span.IsEmpty) { - // Compare 0 with 1 - tmpResult = Vector128.Shuffle(result.AsInt64(), Vector128.Create(1, 0)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsInt64(), Vector128.Create(1, 0)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + Vector256 current; + if (span.Length >= Vector256.Count) + { + current = Vector256.Create(span); + span = span.Slice(Vector256.Count); + } + else + { + // Process a final back-shifted to cover remaining elements in x in one vector. + int start = x.Length - Vector256.Count; + current = Vector256.Create(x.Slice(start)); + currentIndex = Vector256.Create(TInt.CreateChecked(start)) + Vector256.Indices; + span = ReadOnlySpan.Empty; + } + + // Quick return if possible. + mask = TOperator.IsQuickReturn(current); + if (mask != Vector256.Zero) + { + return int.CreateChecked(currentIndex.ToScalar()) + IndexOfFirstMatch(mask); + } - // Return 0 - return (int)resultIndex.As().ToScalar(); + // Get mask for which lanes that should have result updated. + mask = TOperator.Compare(current, result); + + // Update result and indices. + result = ElementWiseSelect(mask, current, result); + resultIndex = ElementWiseSelect(mask.As(), currentIndex, resultIndex); + currentIndex += indexIncrement; } - if (sizeof(T) == 4) { - // Compare 0,1 with 2,3 - tmpResult = Vector128.Shuffle(result.AsInt32(), Vector128.Create(2, 3, 0, 1)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsInt32(), Vector128.Create(2, 3, 0, 1)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. + T aggResult = TOperator.Aggregate(result); + Vector256 aggMask = ~Vector256.Equals(result.As(), Vector256.Create(aggResult).As()); + Vector256 aggIndex = resultIndex | aggMask; + return int.CreateChecked(HorizontalAggregate>(aggIndex)); + } + } - // Compare 0 with 1 - tmpResult = Vector128.Shuffle(result.AsInt32(), Vector128.Create(1, 0, 3, 2)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsInt32(), Vector128.Create(1, 0, 3, 2)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + private static int IndexOfMinMaxVector256Size2(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator + { + Debug.Assert(sizeof(T) == 2); - // Return 0 - return resultIndex.As().ToScalar(); + // Initialize result by reading first vector and quick return if possible. + Vector256 result = Vector256.Create(x); + Vector256 mask = TOperator.IsQuickReturn(result); + if (mask != Vector256.Zero) + { + return IndexOfFirstMatch(mask); } - if (sizeof(T) == 2) + // Initialize indices. + Vector256 indexIncrement = Vector256.Create((uint)Vector256.Count); + Vector256 resultIndex1 = Vector256.Indices; + Vector256 resultIndex2 = resultIndex1 + indexIncrement; + Vector256 currentIndex = resultIndex2 + indexIncrement; + ReadOnlySpan span = x.Slice(Vector256.Count); + + while (!span.IsEmpty) { - // Compare 0,1,2,3 with 4,5,6,7 - tmpResult = Vector128.Shuffle(result.AsInt16(), Vector128.Create(4, 5, 6, 7, 0, 1, 2, 3)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create(4, 5, 6, 7, 0, 1, 2, 3)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + Vector256 current; + if (span.Length >= Vector256.Count) + { + current = Vector256.Create(span); + span = span.Slice(Vector256.Count); + } + else + { + // Process a final back-shifted to cover remaining elements in x in one vector. + int start = x.Length - Vector256.Count; + current = Vector256.Create(x.Slice(start)); + currentIndex = Vector256.Create((uint)start) + Vector256.Indices; + span = ReadOnlySpan.Empty; + } - // Compare 0,1 with 2,3 - tmpResult = Vector128.Shuffle(result.AsInt16(), Vector128.Create(2, 3, 0, 1, 4, 5, 6, 7)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create(2, 3, 0, 1, 4, 5, 6, 7)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + // Quick return if possible. + mask = TOperator.IsQuickReturn(current); + if (mask != Vector256.Zero) + { + return (int)currentIndex.ToScalar() + IndexOfFirstMatch(mask); + } - // Compare 0 with 1 - tmpResult = Vector128.Shuffle(result.AsInt16(), Vector128.Create(1, 0, 2, 3, 4, 5, 6, 7)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create(1, 0, 2, 3, 4, 5, 6, 7)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + // Get mask for which lanes that should have result updated, also widen it for updating the indices. + mask = TOperator.Compare(current, result); + (Vector256 mask1, Vector256 mask2) = Vector256.Widen(mask.AsInt16()); - // Return 0 - return resultIndex.As().ToScalar(); + // Update result and indices. + result = ElementWiseSelect(mask, current, result); + resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1); + currentIndex += indexIncrement; + resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2); + currentIndex += indexIncrement; } + { + // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. + T aggResult = TOperator.Aggregate(result); + Vector256 aggMask = ~Vector256.Equals(result.AsInt16(), Vector256.Create(aggResult).AsInt16()); + + (Vector256 mask1, Vector256 mask2) = Vector256.Widen(aggMask); + Vector256 aggIndex = resultIndex1 | mask1.AsUInt32(); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32()); + + return (int)HorizontalAggregate>(aggIndex); + } + } + + private static int IndexOfMinMaxVector256Size1(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator + { Debug.Assert(sizeof(T) == 1); + + // Initialize result by reading first vector and quick return if possible. + Vector256 result = Vector256.Create(x); + Vector256 mask = TOperator.IsQuickReturn(result); + if (mask != Vector256.Zero) { - // Compare 0,1,2,3,4,5,6,7 with 8,9,10,11,12,13,14,15 - tmpResult = Vector128.Shuffle(result.AsByte(), Vector128.Create((byte)8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + return IndexOfFirstMatch(mask); + } + + // Initialize indices. + Vector256 indexIncrement = Vector256.Create((uint)Vector256.Count); + Vector256 resultIndex1 = Vector256.Indices; + Vector256 resultIndex2 = resultIndex1 + indexIncrement; + Vector256 resultIndex3 = resultIndex2 + indexIncrement; + Vector256 resultIndex4 = resultIndex3 + indexIncrement; + Vector256 currentIndex = resultIndex4 + indexIncrement; + ReadOnlySpan span = x.Slice(Vector256.Count); + + while (!span.IsEmpty) + { + Vector256 current; + if (span.Length >= Vector256.Count) + { + current = Vector256.Create(span); + span = span.Slice(Vector256.Count); + } + else + { + // Process a final back-shifted to cover remaining elements in x in one vector. + int start = x.Length - Vector256.Count; + current = Vector256.Create(x.Slice(start)); + currentIndex = Vector256.Create((uint)start) + Vector256.Indices; + span = ReadOnlySpan.Empty; + } + + // Quick return if possible. + mask = TOperator.IsQuickReturn(current); + if (mask != Vector256.Zero) + { + return (int)currentIndex.ToScalar() + IndexOfFirstMatch(mask); + } + + // Get mask for which lanes that should have result updated, also widen it for updating the indices. + mask = TOperator.Compare(current, result); + (Vector256 lowerMask, Vector256 upperMask) = Vector256.Widen(mask.AsSByte()); + (Vector256 mask1, Vector256 mask2) = Vector256.Widen(lowerMask); + (Vector256 mask3, Vector256 mask4) = Vector256.Widen(upperMask); - // Compare 0,1,2,3 with 4,5,6,7 - tmpResult = Vector128.Shuffle(result.AsByte(), Vector128.Create((byte)4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + // Update result and indices. + result = ElementWiseSelect(mask, current, result); + resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1); + currentIndex += indexIncrement; + resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2); + currentIndex += indexIncrement; + resultIndex3 = ElementWiseSelect(mask3.AsUInt32(), currentIndex, resultIndex3); + currentIndex += indexIncrement; + resultIndex4 = ElementWiseSelect(mask4.AsUInt32(), currentIndex, resultIndex4); + currentIndex += indexIncrement; + } - // Compare 0,1 with 2,3 - tmpResult = Vector128.Shuffle(result.AsByte(), Vector128.Create((byte)2, 3, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)2, 3, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + { + // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. + T aggResult = TOperator.Aggregate(result); + Vector256 aggMask = ~Vector256.Equals(result.AsSByte(), Vector256.Create(aggResult).AsSByte()); - // Compare 0 with 1 - tmpResult = Vector128.Shuffle(result.AsByte(), Vector128.Create((byte)1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As(); - tmpIndex = Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As(); - TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + (Vector256 lowerMask, Vector256 upperMask) = Vector256.Widen(aggMask); + (Vector256 mask1, Vector256 mask2) = Vector256.Widen(lowerMask); + (Vector256 mask3, Vector256 mask4) = Vector256.Widen(upperMask); + Vector256 aggIndex = resultIndex1 | mask1.AsUInt32(); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32()); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex3 | mask3.AsUInt32()); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex4 | mask4.AsUInt32()); - // Return 0 - return resultIndex.As().ToScalar(); + return (int)HorizontalAggregate>(aggIndex); } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int IndexOfFinalAggregate(Vector256 result, Vector256 resultIndex) - where TIndexOfOperator : struct, IIndexOfOperator + private static int IndexOfMinMaxVector512Size4Plus(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator where TInt : IBinaryInteger { - // Min the upper/lower halves of the Vector256 - Vector128 resultLower = result.GetLower(); - Vector128 indexLower = resultIndex.GetLower(); + Debug.Assert(sizeof(T) == 4 || sizeof(T) == 8); + Debug.Assert(typeof(TInt) == typeof(uint) || typeof(TInt) == typeof(ulong)); + Debug.Assert(sizeof(TInt) == sizeof(T)); + + // Initialize result by reading first vector and quick return if possible. + Vector512 result = Vector512.Create(x); + Vector512 mask = TOperator.IsQuickReturn(result); + if (mask != Vector512.Zero) + { + return IndexOfFirstMatch(mask); + } + + // Initialize indices. + Vector512 indexIncrement = Vector512.Create(TInt.CreateChecked(Vector512.Count)); + Vector512 resultIndex = Vector512.Indices; + Vector512 currentIndex = resultIndex + indexIncrement; + ReadOnlySpan span = x.Slice(Vector512.Count); + + while (!span.IsEmpty) + { + Vector512 current; + if (span.Length >= Vector512.Count) + { + current = Vector512.Create(span); + span = span.Slice(Vector512.Count); + } + else + { + // Process a final back-shifted to cover remaining elements in x in one vector. + int start = x.Length - Vector512.Count; + current = Vector512.Create(x.Slice(start)); + currentIndex = Vector512.Create(TInt.CreateChecked(start)) + Vector512.Indices; + span = ReadOnlySpan.Empty; + } + + // Quick return if possible. + mask = TOperator.IsQuickReturn(current); + if (mask != Vector512.Zero) + { + return int.CreateChecked(currentIndex.ToScalar()) + IndexOfFirstMatch(mask); + } - TIndexOfOperator.Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); - return IndexOfFinalAggregate(resultLower, indexLower); + // Get mask for which lanes that should have result updated. + mask = TOperator.Compare(current, result); + + // Update result and indices. + result = ElementWiseSelect(mask, current, result); + resultIndex = ElementWiseSelect(mask.As(), currentIndex, resultIndex); + currentIndex += indexIncrement; + } + + { + // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. + T aggResult = TOperator.Aggregate(result); + Vector512 aggMask = ~Vector512.Equals(result.As(), Vector512.Create(aggResult).As()); + Vector512 aggIndex = resultIndex | aggMask; + return int.CreateChecked(HorizontalAggregate>(aggIndex)); + } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int IndexOfFinalAggregate(Vector512 result, Vector512 resultIndex) - where TIndexOfOperator : struct, IIndexOfOperator + private static int IndexOfMinMaxVector512Size2(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator { - Vector256 resultLower = result.GetLower(); - Vector256 indexLower = resultIndex.GetLower(); + Debug.Assert(sizeof(T) == 2); - TIndexOfOperator.Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); - return IndexOfFinalAggregate(resultLower, indexLower); + // Initialize result by reading first vector and quick return if possible. + Vector512 result = Vector512.Create(x); + Vector512 mask = TOperator.IsQuickReturn(result); + if (mask != Vector512.Zero) + { + return IndexOfFirstMatch(mask); + } + + // Initialize indices. + Vector512 indexIncrement = Vector512.Create((uint)Vector512.Count); + Vector512 resultIndex1 = Vector512.Indices; + Vector512 resultIndex2 = resultIndex1 + indexIncrement; + Vector512 currentIndex = resultIndex2 + indexIncrement; + ReadOnlySpan span = x.Slice(Vector512.Count); + + while (!span.IsEmpty) + { + Vector512 current; + if (span.Length >= Vector512.Count) + { + current = Vector512.Create(span); + span = span.Slice(Vector512.Count); + } + else + { + // Process a final back-shifted to cover remaining elements in x in one vector. + int start = x.Length - Vector512.Count; + current = Vector512.Create(x.Slice(start)); + currentIndex = Vector512.Create((uint)start) + Vector512.Indices; + span = ReadOnlySpan.Empty; + } + + // Quick return if possible. + mask = TOperator.IsQuickReturn(current); + if (mask != Vector512.Zero) + { + return (int)currentIndex.ToScalar() + IndexOfFirstMatch(mask); + } + + // Get mask for which lanes that should have result updated, also widen it for updating the indices. + mask = TOperator.Compare(current, result); + (Vector512 mask1, Vector512 mask2) = Vector512.Widen(mask.AsInt16()); + + // Update result and indices. + result = ElementWiseSelect(mask, current, result); + resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1); + currentIndex += indexIncrement; + resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2); + currentIndex += indexIncrement; + } + + { + // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. + T aggResult = TOperator.Aggregate(result); + Vector512 aggMask = ~Vector512.Equals(result.AsInt16(), Vector512.Create(aggResult).AsInt16()); + + (Vector512 mask1, Vector512 mask2) = Vector512.Widen(aggMask); + Vector512 aggIndex = resultIndex1 | mask1.AsUInt32(); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32()); + + return (int)HorizontalAggregate>(aggIndex); + } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector128 IndexLessThan(Vector128 indices1, Vector128 indices2) => - sizeof(T) == sizeof(long) ? Vector128.LessThan(indices1.AsInt64(), indices2.AsInt64()).As() : - sizeof(T) == sizeof(int) ? Vector128.LessThan(indices1.AsInt32(), indices2.AsInt32()).As() : - sizeof(T) == sizeof(short) ? Vector128.LessThan(indices1.AsInt16(), indices2.AsInt16()).As() : - Vector128.LessThan(indices1.AsByte(), indices2.AsByte()).As(); + private static int IndexOfMinMaxVector512Size1(ReadOnlySpan x) + where T : INumber where TOperator : struct, IIndexOfMinMaxOperator + { + Debug.Assert(sizeof(T) == 1); + + // Initialize result by reading first vector and quick return if possible. + Vector512 result = Vector512.Create(x); + Vector512 mask = TOperator.IsQuickReturn(result); + if (mask != Vector512.Zero) + { + return IndexOfFirstMatch(mask); + } + + // Initialize indices. + Vector512 indexIncrement = Vector512.Create((uint)Vector512.Count); + Vector512 resultIndex1 = Vector512.Indices; + Vector512 resultIndex2 = resultIndex1 + indexIncrement; + Vector512 resultIndex3 = resultIndex2 + indexIncrement; + Vector512 resultIndex4 = resultIndex3 + indexIncrement; + Vector512 currentIndex = resultIndex4 + indexIncrement; + ReadOnlySpan span = x.Slice(Vector512.Count); + + while (!span.IsEmpty) + { + Vector512 current; + if (span.Length >= Vector512.Count) + { + current = Vector512.Create(span); + span = span.Slice(Vector512.Count); + } + else + { + // Process a final back-shifted to cover remaining elements in x in one vector. + int start = x.Length - Vector512.Count; + current = Vector512.Create(x.Slice(start)); + currentIndex = Vector512.Create((uint)start) + Vector512.Indices; + span = ReadOnlySpan.Empty; + } + + // Quick return if possible. + mask = TOperator.IsQuickReturn(current); + if (mask != Vector512.Zero) + { + return (int)currentIndex.ToScalar() + IndexOfFirstMatch(mask); + } + + // Get mask for which lanes that should have result updated, also widen it for updating the indices. + mask = TOperator.Compare(current, result); + (Vector512 lowerMask, Vector512 upperMask) = Vector512.Widen(mask.AsSByte()); + (Vector512 mask1, Vector512 mask2) = Vector512.Widen(lowerMask); + (Vector512 mask3, Vector512 mask4) = Vector512.Widen(upperMask); + + // Update result and indices. + result = ElementWiseSelect(mask, current, result); + resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1); + currentIndex += indexIncrement; + resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2); + currentIndex += indexIncrement; + resultIndex3 = ElementWiseSelect(mask3.AsUInt32(), currentIndex, resultIndex3); + currentIndex += indexIncrement; + resultIndex4 = ElementWiseSelect(mask4.AsUInt32(), currentIndex, resultIndex4); + currentIndex += indexIncrement; + } + + { + // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. + T aggResult = TOperator.Aggregate(result); + Vector512 aggMask = ~Vector512.Equals(result.AsSByte(), Vector512.Create(aggResult).AsSByte()); + + (Vector512 lowerMask, Vector512 upperMask) = Vector512.Widen(aggMask); + (Vector512 mask1, Vector512 mask2) = Vector512.Widen(lowerMask); + (Vector512 mask3, Vector512 mask4) = Vector512.Widen(upperMask); + Vector512 aggIndex = resultIndex1 | mask1.AsUInt32(); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32()); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex3 | mask3.AsUInt32()); + aggIndex = MinOperator.Invoke(aggIndex, resultIndex4 | mask4.AsUInt32()); + + return (int)HorizontalAggregate>(aggIndex); + } + } } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs index f40f7e1e2e2ba0..12292dac6b9a7b 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs @@ -1,9 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Diagnostics; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86; @@ -29,394 +27,75 @@ public static int IndexOfMax(ReadOnlySpan x) IndexOfMinMaxCore>(x); /// Returns the index of MathF.Max(x, y) - internal readonly struct IndexOfMaxOperator : IIndexOfOperator where T : INumber + internal readonly struct IndexOfMaxOperator : IIndexOfMinMaxOperator where T : INumber { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 currentIndex) - { - Vector128 useResult = Vector128.GreaterThan(result, current); - Vector128 equalMask = Vector128.Equals(result, current); - - if (equalMask != Vector128.Zero) - { - Vector128 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(current)); - Vector128 currentNegative = IsNegative(current); - Vector128 sameSign = Vector128.Equals(IsNegative(result).AsInt32(), currentNegative.AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, currentNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } - } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); - } + public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 currentIndex) - { - Vector256 useResult = Vector256.GreaterThan(result, current); - Vector256 equalMask = Vector256.Equals(result, current); - - if (equalMask != Vector256.Zero) - { - Vector256 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(current)); - Vector256 currentNegative = IsNegative(current); - Vector256 sameSign = Vector256.Equals(IsNegative(result).AsInt32(), currentNegative.AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, currentNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } - } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); - } - + public static bool IsQuickReturn(T value) => T.IsNaN(value); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 currentIndex) - { - Vector512 useResult = Vector512.GreaterThan(result, current); - Vector512 equalMask = Vector512.Equals(result, current); - - if (equalMask != Vector512.Zero) - { - Vector512 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(current)); - Vector512 currentNegative = IsNegative(current); - Vector512 sameSign = Vector512.Equals(IsNegative(result).AsInt32(), currentNegative.AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, currentNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } - } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); - } + public static Vector128 IsQuickReturn(Vector128 value) => IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 IsQuickReturn(Vector256 value) => IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 IsQuickReturn(Vector512 value) => IsNaN(value); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static int Invoke(ref T result, T current, int resultIndex, int currentIndex) + public static bool Compare(T x, T y) { - if (result == current) + if (x == y) { - bool resultNegative = IsNegative(result); - if ((resultNegative == IsNegative(current)) ? (currentIndex < resultIndex) : resultNegative) - { - result = current; - return currentIndex; - } + return T.IsPositive(x) && T.IsNegative(y); } - else if (current > result) + else { - result = current; - return currentIndex; + return x > y; } - - return resultIndex; } - } - - private static unsafe int IndexOfMinMaxCore(ReadOnlySpan x) - where T : INumber - where TIndexOfMinMax : struct, IIndexOfOperator - { - if (x.IsEmpty) - { - return -1; - } - - // This matches the IEEE 754:2019 `maximum`/`minimum` functions. - // It propagates NaN inputs back to the caller and - // otherwise returns the index of the greater of the inputs. - // It treats +0 as greater than -0 as per the specification. - if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && x.Length >= Vector512.Count) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Compare(Vector128 x, Vector128 y) { - Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static Vector512 CreateVector512T(int i) => - sizeof(T) == sizeof(long) ? Vector512.Create((long)i).As() : - sizeof(T) == sizeof(int) ? Vector512.Create(i).As() : - sizeof(T) == sizeof(short) ? Vector512.Create((short)i).As() : - Vector512.Create((byte)i).As(); - - ref T xRef = ref MemoryMarshal.GetReference(x); - Vector512 resultIndex = - sizeof(T) == sizeof(long) ? Vector512.Indices.As() : - sizeof(T) == sizeof(int) ? Vector512.Indices.As() : - sizeof(T) == sizeof(short) ? Vector512.Indices.As() : - Vector512.Indices.As(); - Vector512 currentIndex = resultIndex; - Vector512 increment = CreateVector512T(Vector512.Count); - - // Load the first vector as the initial set of results, and bail immediately - // to scalar handling if it contains any NaNs (which don't compare equally to themselves). - Vector512 result = Vector512.LoadUnsafe(ref xRef); - Vector512 current; - - Vector512 nanMask; - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - nanMask = ~Vector512.Equals(result, result); - if (nanMask != Vector512.Zero) - { - return IndexOfFirstMatch(nanMask); - } + Vector128 equalResult = IsPositive(x) & IsNegative(y); + return Vector128.GreaterThan(x, y) | (Vector128.Equals(x, y) & equalResult); } - - int oneVectorFromEnd = x.Length - Vector512.Count; - int i = Vector512.Count; - - // Aggregate additional vectors into the result as long as there's at least one full vector left to process. - while (i <= oneVectorFromEnd) + else { - // Load the next vector, and early exit on NaN. - current = Vector512.LoadUnsafe(ref xRef, (uint)i); - currentIndex += increment; - - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - nanMask = ~Vector512.Equals(current, current); - if (nanMask != Vector512.Zero) - { - return i + IndexOfFirstMatch(nanMask); - } - } - - TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex); - - i += Vector512.Count; + return Vector128.GreaterThan(x, y); } - - // If any elements remain, handle them in one final vector. - if (i != x.Length) - { - current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); - currentIndex += CreateVector512T(x.Length - i); - - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - nanMask = ~Vector512.Equals(current, current); - if (nanMask != Vector512.Zero) - { - int indexInVectorOfFirstMatch = IndexOfFirstMatch(nanMask); - return typeof(T) == typeof(double) ? - (int)(long)(object)currentIndex.As()[indexInVectorOfFirstMatch] : - (int)(object)currentIndex.As()[indexInVectorOfFirstMatch]; - } - } - - TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex); - } - - // Aggregate the lanes in the vector to create the final scalar result. - return IndexOfFinalAggregate(result, resultIndex); } - if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && x.Length >= Vector256.Count) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Compare(Vector256 x, Vector256 y) { - Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static Vector256 CreateVector256T(int i) => - sizeof(T) == sizeof(long) ? Vector256.Create((long)i).As() : - sizeof(T) == sizeof(int) ? Vector256.Create(i).As() : - sizeof(T) == sizeof(short) ? Vector256.Create((short)i).As() : - Vector256.Create((byte)i).As(); - - ref T xRef = ref MemoryMarshal.GetReference(x); - Vector256 resultIndex = - sizeof(T) == sizeof(long) ? Vector256.Indices.As() : - sizeof(T) == sizeof(int) ? Vector256.Indices.As() : - sizeof(T) == sizeof(short) ? Vector256.Indices.As() : - Vector256.Indices.As(); - Vector256 currentIndex = resultIndex; - Vector256 increment = CreateVector256T(Vector256.Count); - - // Load the first vector as the initial set of results, and bail immediately - // to scalar handling if it contains any NaNs (which don't compare equally to themselves). - Vector256 result = Vector256.LoadUnsafe(ref xRef); - Vector256 current; - - Vector256 nanMask; - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - nanMask = ~Vector256.Equals(result, result); - if (nanMask != Vector256.Zero) - { - return IndexOfFirstMatch(nanMask); - } + Vector256 equalResult = IsPositive(x) & IsNegative(y); + return Vector256.GreaterThan(x, y) | (Vector256.Equals(x, y) & equalResult); } - - int oneVectorFromEnd = x.Length - Vector256.Count; - int i = Vector256.Count; - - // Aggregate additional vectors into the result as long as there's at least one full vector left to process. - while (i <= oneVectorFromEnd) + else { - // Load the next vector, and early exit on NaN. - current = Vector256.LoadUnsafe(ref xRef, (uint)i); - currentIndex += increment; - - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - nanMask = ~Vector256.Equals(current, current); - if (nanMask != Vector256.Zero) - { - return i + IndexOfFirstMatch(nanMask); - } - } - - TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex); - - i += Vector256.Count; + return Vector256.GreaterThan(x, y); } - - // If any elements remain, handle them in one final vector. - if (i != x.Length) - { - current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); - currentIndex += CreateVector256T(x.Length - i); - - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - nanMask = ~Vector256.Equals(current, current); - if (nanMask != Vector256.Zero) - { - int indexInVectorOfFirstMatch = IndexOfFirstMatch(nanMask); - return typeof(T) == typeof(double) ? - (int)(long)(object)currentIndex.As()[indexInVectorOfFirstMatch] : - (int)(object)currentIndex.As()[indexInVectorOfFirstMatch]; - } - } - - TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex); - } - - // Aggregate the lanes in the vector to create the final scalar result. - return IndexOfFinalAggregate(result, resultIndex); } - if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && x.Length >= Vector128.Count) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Compare(Vector512 x, Vector512 y) { - Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static Vector128 CreateVector128T(int i) => - sizeof(T) == sizeof(long) ? Vector128.Create((long)i).As() : - sizeof(T) == sizeof(int) ? Vector128.Create(i).As() : - sizeof(T) == sizeof(short) ? Vector128.Create((short)i).As() : - Vector128.Create((byte)i).As(); - - ref T xRef = ref MemoryMarshal.GetReference(x); - Vector128 resultIndex = - sizeof(T) == sizeof(long) ? Vector128.Indices.As() : - sizeof(T) == sizeof(int) ? Vector128.Indices.As() : - sizeof(T) == sizeof(short) ? Vector128.Indices.As() : - Vector128.Indices.As(); - Vector128 currentIndex = resultIndex; - Vector128 increment = CreateVector128T(Vector128.Count); - - // Load the first vector as the initial set of results, and bail immediately - // to scalar handling if it contains any NaNs (which don't compare equally to themselves). - Vector128 result = Vector128.LoadUnsafe(ref xRef); - Vector128 current; - - Vector128 nanMask; - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - nanMask = ~Vector128.Equals(result, result); - if (nanMask != Vector128.Zero) - { - return IndexOfFirstMatch(nanMask); - } - } - - int oneVectorFromEnd = x.Length - Vector128.Count; - int i = Vector128.Count; - - // Aggregate additional vectors into the result as long as there's at least one full vector left to process. - while (i <= oneVectorFromEnd) + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - // Load the next vector, and early exit on NaN. - current = Vector128.LoadUnsafe(ref xRef, (uint)i); - currentIndex += increment; - - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - nanMask = ~Vector128.Equals(current, current); - if (nanMask != Vector128.Zero) - { - return i + IndexOfFirstMatch(nanMask); - } - } - - TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex); - - i += Vector128.Count; + Vector512 equalResult = IsPositive(x) & IsNegative(y); + return Vector512.GreaterThan(x, y) | (Vector512.Equals(x, y) & equalResult); } - - // If any elements remain, handle them in one final vector. - if (i != x.Length) + else { - current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); - currentIndex += CreateVector128T(x.Length - i); - - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - nanMask = ~Vector128.Equals(current, current); - if (nanMask != Vector128.Zero) - { - int indexInVectorOfFirstMatch = IndexOfFirstMatch(nanMask); - return typeof(T) == typeof(double) ? - (int)(long)(object)currentIndex.As()[indexInVectorOfFirstMatch] : - (int)(object)currentIndex.As()[indexInVectorOfFirstMatch]; - } - } - - TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex); + return Vector512.GreaterThan(x, y); } - - // Aggregate the lanes in the vector to create the final scalar result. - return IndexOfFinalAggregate(result, resultIndex); - } - - // Scalar path used when either vectorization is not supported or the input is too small to vectorize. - T curResult = x[0]; - int curIn = 0; - if (T.IsNaN(curResult)) - { - return curIn; } - - for (int i = 1; i < x.Length; i++) - { - T current = x[i]; - if (T.IsNaN(current)) - { - return i; - } - - curIn = TIndexOfMinMax.Invoke(ref curResult, current, curIn, i); - } - - return curIn; } private static int IndexOfFirstMatch(Vector128 mask) => @@ -428,23 +107,6 @@ private static int IndexOfFirstMatch(Vector256 mask) => private static int IndexOfFirstMatch(Vector512 mask) => BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector256 IndexLessThan(Vector256 indices1, Vector256 indices2) => - sizeof(T) == sizeof(long) ? Vector256.LessThan(indices1.AsInt64(), indices2.AsInt64()).As() : - sizeof(T) == sizeof(int) ? Vector256.LessThan(indices1.AsInt32(), indices2.AsInt32()).As() : - sizeof(T) == sizeof(short) ? Vector256.LessThan(indices1.AsInt16(), indices2.AsInt16()).As() : - Vector256.LessThan(indices1.AsByte(), indices2.AsByte()).As(); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe Vector512 IndexLessThan(Vector512 indices1, Vector512 indices2) => - sizeof(T) == sizeof(long) ? Vector512.LessThan(indices1.AsInt64(), indices2.AsInt64()).As() : - sizeof(T) == sizeof(int) ? Vector512.LessThan(indices1.AsInt32(), indices2.AsInt32()).As() : - sizeof(T) == sizeof(short) ? Vector512.LessThan(indices1.AsInt16(), indices2.AsInt16()).As() : - Vector512.LessThan(indices1.AsByte(), indices2.AsByte()).As(); - - /// Gets whether the specified is negative. - private static bool IsNegative(T f) where T : INumberBase => T.IsNegative(f); - [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) { diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitude.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitude.cs index f1f5016a86b13e..a4e16d4f35ece3 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitude.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitude.cs @@ -26,111 +26,78 @@ public static int IndexOfMaxMagnitude(ReadOnlySpan x) where T : INumber => IndexOfMinMaxCore>(x); - internal readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator where T : INumber + internal readonly struct IndexOfMaxMagnitudeOperator : IIndexOfMinMaxOperator where T : INumber { + public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); + [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 currentIndex) - { - Vector128 resultMag = Vector128.Abs(result), currentMag = Vector128.Abs(current); - Vector128 useResult = Vector128.GreaterThan(resultMag, currentMag); - Vector128 equalMask = Vector128.Equals(resultMag, currentMag); + public static bool IsQuickReturn(T value) => T.IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 IsQuickReturn(Vector128 value) => IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 IsQuickReturn(Vector256 value) => IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 IsQuickReturn(Vector512 value) => IsNaN(value); - if (equalMask != Vector128.Zero) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool Compare(T x, T y) + { + T xMag = T.Abs(x), yMag = T.Abs(y); + if (xMag == yMag) { - Vector128 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(current)); - Vector128 currentNegative = IsNegative(current); - Vector128 sameSign = Vector128.Equals(IsNegative(result).AsInt32(), currentNegative.AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, currentNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } + return T.IsPositive(x) && T.IsNegative(y); + } + else + { + return xMag > yMag; } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 currentIndex) + public static Vector128 Compare(Vector128 x, Vector128 y) { - Vector256 resultMag = Vector256.Abs(result), currentMag = Vector256.Abs(current); - Vector256 useResult = Vector256.GreaterThan(resultMag, currentMag); - Vector256 equalMask = Vector256.Equals(resultMag, currentMag); - - if (equalMask != Vector256.Zero) + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - Vector256 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(current)); - Vector256 currentNegative = IsNegative(current); - Vector256 sameSign = Vector256.Equals(IsNegative(result).AsInt32(), currentNegative.AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, currentNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } + Vector128 equalResult = IsPositive(x) & IsNegative(y); + return Vector128.GreaterThan(xMag, yMag) | (Vector128.Equals(xMag, yMag) & equalResult); + } + else + { + return Vector128.GreaterThan(xMag, yMag); } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 currentIndex) + public static Vector256 Compare(Vector256 x, Vector256 y) { - Vector512 resultMag = Vector512.Abs(result), currentMag = Vector512.Abs(current); - Vector512 useResult = Vector512.GreaterThan(resultMag, currentMag); - Vector512 equalMask = Vector512.Equals(resultMag, currentMag); - - if (equalMask != Vector512.Zero) + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - Vector512 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(current)); - Vector512 currentNegative = IsNegative(current); - Vector512 sameSign = Vector512.Equals(IsNegative(result).AsInt32(), currentNegative.AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, currentNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } + Vector256 equalResult = IsPositive(x) & IsNegative(y); + return Vector256.GreaterThan(xMag, yMag) | (Vector256.Equals(xMag, yMag) & equalResult); + } + else + { + return Vector256.GreaterThan(xMag, yMag); } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static int Invoke(ref T result, T current, int resultIndex, int currentIndex) + public static Vector512 Compare(Vector512 x, Vector512 y) { - T resultMag = T.Abs(result); - T currentMag = T.Abs(current); - - if (resultMag == currentMag) + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - bool resultNegative = IsNegative(result); - if ((resultNegative == IsNegative(current)) ? (currentIndex < resultIndex) : resultNegative) - { - result = current; - return currentIndex; - } + Vector512 equalResult = IsPositive(x) & IsNegative(y); + return Vector512.GreaterThan(xMag, yMag) | (Vector512.Equals(xMag, yMag) & equalResult); } - else if (currentMag > resultMag) + else { - result = current; - return currentIndex; + return Vector512.GreaterThan(xMag, yMag); } - - return resultIndex; } } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMin.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMin.cs index 011021b6c0015f..c69e5727147ecd 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMin.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMin.cs @@ -26,105 +26,74 @@ public static int IndexOfMin(ReadOnlySpan x) IndexOfMinMaxCore>(x); /// Returns the index of MathF.Min(x, y) - internal readonly struct IndexOfMinOperator : IIndexOfOperator where T : INumber + internal readonly struct IndexOfMinOperator : IIndexOfMinMaxOperator where T : INumber { + public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); + [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 currentIndex) - { - Vector128 useResult = Vector128.LessThan(result, current); - Vector128 equalMask = Vector128.Equals(result, current); + public static bool IsQuickReturn(T value) => T.IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 IsQuickReturn(Vector128 value) => IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 IsQuickReturn(Vector256 value) => IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 IsQuickReturn(Vector512 value) => IsNaN(value); - if (equalMask != Vector128.Zero) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool Compare(T x, T y) + { + if (x == y) { - Vector128 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(result)); - Vector128 resultNegative = IsNegative(result); - Vector128 sameSign = Vector128.Equals(resultNegative.AsInt32(), IsNegative(current).AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } + return T.IsNegative(x) && T.IsPositive(y); + } + else + { + return x < y; } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 currentIndex) + public static Vector128 Compare(Vector128 x, Vector128 y) { - Vector256 useResult = Vector256.LessThan(result, current); - Vector256 equalMask = Vector256.Equals(result, current); - - if (equalMask != Vector256.Zero) + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - Vector256 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(result)); - Vector256 resultNegative = IsNegative(result); - Vector256 sameSign = Vector256.Equals(resultNegative.AsInt32(), IsNegative(current).AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } + Vector128 equalResult = IsNegative(x) & IsPositive(y); + return Vector128.LessThan(x, y) | (Vector128.Equals(x, y) & equalResult); + } + else + { + return Vector128.LessThan(x, y); } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 currentIndex) + public static Vector256 Compare(Vector256 x, Vector256 y) { - Vector512 useResult = Vector512.LessThan(result, current); - Vector512 equalMask = Vector512.Equals(result, current); - - if (equalMask != Vector512.Zero) + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - Vector512 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(result)); - Vector512 resultNegative = IsNegative(result); - Vector512 sameSign = Vector512.Equals(resultNegative.AsInt32(), IsNegative(current).AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } + Vector256 equalResult = IsNegative(x) & IsPositive(y); + return Vector256.LessThan(x, y) | (Vector256.Equals(x, y) & equalResult); + } + else + { + return Vector256.LessThan(x, y); } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static int Invoke(ref T result, T current, int resultIndex, int currentIndex) + public static Vector512 Compare(Vector512 x, Vector512 y) { - if (result == current) + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - bool currentNegative = IsNegative(current); - if ((IsNegative(result) == currentNegative) ? (currentIndex < resultIndex) : currentNegative) - { - result = current; - return currentIndex; - } + Vector512 equalResult = IsNegative(x) & IsPositive(y); + return Vector512.LessThan(x, y) | (Vector512.Equals(x, y) & equalResult); } - else if (current < result) + else { - result = current; - return currentIndex; + return Vector512.LessThan(x, y); } - - return resultIndex; } } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitude.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitude.cs index 813bcf4637dd12..3682382a3cb0d4 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitude.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitude.cs @@ -26,111 +26,78 @@ public static int IndexOfMinMagnitude(ReadOnlySpan x) where T : INumber => IndexOfMinMaxCore>(x); - internal readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator where T : INumber + internal readonly struct IndexOfMinMagnitudeOperator : IIndexOfMinMaxOperator where T : INumber { + public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); + [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 currentIndex) - { - Vector128 resultMag = Vector128.Abs(result), currentMag = Vector128.Abs(current); - Vector128 useResult = Vector128.LessThan(resultMag, currentMag); - Vector128 equalMask = Vector128.Equals(resultMag, currentMag); + public static bool IsQuickReturn(T value) => T.IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 IsQuickReturn(Vector128 value) => IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 IsQuickReturn(Vector256 value) => IsNaN(value); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 IsQuickReturn(Vector512 value) => IsNaN(value); - if (equalMask != Vector128.Zero) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool Compare(T x, T y) + { + T xMag = T.Abs(x), yMag = T.Abs(y); + if (xMag == yMag) { - Vector128 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(result)); - Vector128 resultNegative = IsNegative(result); - Vector128 sameSign = Vector128.Equals(resultNegative.AsInt32(), IsNegative(current).AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } + return T.IsNegative(x) && T.IsPositive(y); + } + else + { + return xMag < yMag; } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 currentIndex) + public static Vector128 Compare(Vector128 x, Vector128 y) { - Vector256 resultMag = Vector256.Abs(result), currentMag = Vector256.Abs(current); - Vector256 useResult = Vector256.LessThan(resultMag, currentMag); - Vector256 equalMask = Vector256.Equals(resultMag, currentMag); - - if (equalMask != Vector256.Zero) + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - Vector256 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(result)); - Vector256 resultNegative = IsNegative(result); - Vector256 sameSign = Vector256.Equals(resultNegative.AsInt32(), IsNegative(current).AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } + Vector128 equalResult = IsNegative(x) & IsPositive(y); + return Vector128.LessThan(xMag, yMag) | (Vector128.Equals(xMag, yMag) & equalResult); + } + else + { + return Vector128.LessThan(xMag, yMag); } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 currentIndex) + public static Vector256 Compare(Vector256 x, Vector256 y) { - Vector512 resultMag = Vector512.Abs(result), currentMag = Vector512.Abs(current); - Vector512 useResult = Vector512.LessThan(resultMag, currentMag); - Vector512 equalMask = Vector512.Equals(resultMag, currentMag); - - if (equalMask != Vector512.Zero) + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - Vector512 lessThanIndexMask = IndexLessThan(resultIndex, currentIndex); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) - { - // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(result)); - Vector512 resultNegative = IsNegative(result); - Vector512 sameSign = Vector512.Equals(resultNegative.AsInt32(), IsNegative(current).AsInt32()).As(); - useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative); - } - else - { - useResult |= equalMask & lessThanIndexMask; - } + Vector256 equalResult = IsNegative(x) & IsPositive(y); + return Vector256.LessThan(xMag, yMag) | (Vector256.Equals(xMag, yMag) & equalResult); + } + else + { + return Vector256.LessThan(xMag, yMag); } - - result = ElementWiseSelect(useResult, result, current); - resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static int Invoke(ref T result, T current, int resultIndex, int currentIndex) + public static Vector512 Compare(Vector512 x, Vector512 y) { - T resultMag = T.Abs(result); - T currentMag = T.Abs(current); - - if (resultMag == currentMag) + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) { - bool currentNegative = IsNegative(current); - if ((IsNegative(result) == currentNegative) ? (currentIndex < resultIndex) : currentNegative) - { - result = current; - return currentIndex; - } + Vector512 equalResult = IsNegative(x) & IsPositive(y); + return Vector512.LessThan(xMag, yMag) | (Vector512.Equals(xMag, yMag) & equalResult); } - else if (currentMag < resultMag) + else { - result = current; - return currentIndex; + return Vector512.LessThan(xMag, yMag); } - - return resultIndex; } } } diff --git a/src/libraries/System.Numerics.Tensors/tests/Helpers.cs b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs index 4e5e89bbb3cad8..ef8ae1406a1979 100644 --- a/src/libraries/System.Numerics.Tensors/tests/Helpers.cs +++ b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs @@ -10,6 +10,9 @@ namespace System.Numerics.Tensors.Tests { public static class Helpers { + public static int SizeGraterThanByte => 260; + public static int SizeGraterThanInt16 => 65540; + public static IEnumerable TensorLengthsIncluding0 => Enumerable.Range(0, 257); public static IEnumerable TensorLengths => Enumerable.Range(1, 256); diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs index f505a2ab77eb1f..35b294e157f570 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs @@ -320,39 +320,100 @@ private static void ConvertToIntegerNativeImpl() // The tests for some types have been marked as OuterLoop simply to decrease inner loop testing time. - public class DoubleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests { } - public class SingleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests { } + public class DoubleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } + + public class SingleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } + public class HalfGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests { + protected override int? IndexOfSizeExceedingMaxValue() => Helpers.SizeGraterThanInt16; protected override void AssertEqualTolerance(Half expected, Half actual, Half? tolerance = null) => base.AssertEqualTolerance(expected, actual, tolerance ?? Half.CreateTruncating(0.001)); } [OuterLoop] - public class NFloatGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests { } + public class NFloatGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } [OuterLoop] - public class SByteGenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } - public class Int16GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } + public class SByteGenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => Helpers.SizeGraterThanByte; + } + + public class Int16GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => Helpers.SizeGraterThanInt16; + } + [OuterLoop] - public class Int32GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } - public class Int64GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } + public class Int32GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } + + public class Int64GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } + [OuterLoop] - public class IntPtrGenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } - public class Int128GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests { } + public class IntPtrGenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } + + public class Int128GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } + + public class ByteGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => Helpers.SizeGraterThanByte; + } - public class ByteGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } [OuterLoop] - public class UInt16GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class UInt16GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => Helpers.SizeGraterThanInt16; + } + [OuterLoop] - public class CharGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } - public class UInt32GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class CharGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => Helpers.SizeGraterThanInt16; + } + + public class UInt32GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } + [OuterLoop] - public class UInt64GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class UInt64GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } + + public class UIntPtrGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } - public class UIntPtrGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } [OuterLoop] - public class UInt128GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests { } + public class UInt128GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests + { + protected override int? IndexOfSizeExceedingMaxValue() => null; + } public unsafe abstract class GenericFloatingPointNumberTensorPrimitivesTests : GenericNumberTensorPrimitivesTests where T : unmanaged, IFloatingPointIeee754, IMinMaxValue diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs index 707e27d2fae419..d7279b3cb0603b 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs @@ -91,6 +91,8 @@ protected override float MinMagnitude(float x, float y) protected override float NegativeOne => -1f; protected override float MinValue => float.MinValue; + protected override int? IndexOfSizeExceedingMaxValue() => null; + protected override IEnumerable<(int Length, float Element)> VectorLengthAndIteratedRange(float min, float max, float increment) { foreach (int length in new[] { 4, 8, 16 }) diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index b21666434904f1..ef1bb01e82a402 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -106,6 +106,8 @@ public abstract class TensorPrimitivesTests where T : unmanaged, IEquatable typeof(T) == typeof(float) || typeof(T) == typeof(double); + protected abstract int? IndexOfSizeExceedingMaxValue(); + protected abstract T ConvertFromSingle(float f); protected abstract IEnumerable GetSpecialValues(); @@ -1134,6 +1136,18 @@ public void IndexOfMax_Negative0LesserThanPositive0() Assert.Equal(1, IndexOfMax([ConvertFromSingle(-1), ConvertFromSingle(-0f)])); Assert.Equal(2, IndexOfMax([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1f)])); } + + [Fact] + public void IndexOfMax_IndexAboveMaxValue() + { + var size = IndexOfSizeExceedingMaxValue(); + if (size == null) return; + + using BoundedMemory x = CreateTensor(size.Value); + x.Span.Fill(ConvertFromSingle(1)); + x.Span[size.Value - 1] = ConvertFromSingle(2); + Assert.Equal(size.Value - 1, IndexOfMax(x)); + } #endregion #region IndexOfMaxMagnitude @@ -1211,6 +1225,18 @@ public void IndexOfMaxMagnitude_Negative0LesserThanPositive0() Assert.Equal(0, IndexOfMaxMagnitude([ConvertFromSingle(-1), ConvertFromSingle(-0f)])); Assert.Equal(2, IndexOfMaxMagnitude([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1f)])); } + + [Fact] + public void IndexOfMaxMagnitude_IndexAboveMaxValue() + { + var size = IndexOfSizeExceedingMaxValue(); + if (size == null) return; + + using BoundedMemory x = CreateTensor(size.Value); + x.Span.Fill(ConvertFromSingle(1)); + x.Span[size.Value - 1] = ConvertFromSingle(2); + Assert.Equal(size.Value - 1, IndexOfMaxMagnitude(x)); + } #endregion #region IndexOfMin @@ -1263,6 +1289,18 @@ public void IndexOfMin_Negative0LesserThanPositive0() Assert.Equal(0, IndexOfMin([ConvertFromSingle(-1), ConvertFromSingle(-0f)])); Assert.Equal(0, IndexOfMin([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1f)])); } + + [Fact] + public void IndexOfMin_IndexAboveMaxValue() + { + var size = IndexOfSizeExceedingMaxValue(); + if (size == null) return; + + using BoundedMemory x = CreateTensor(size.Value); + x.Span.Fill(ConvertFromSingle(1)); + x.Span[size.Value - 1] = ConvertFromSingle(0); + Assert.Equal(size.Value - 1, IndexOfMin(x)); + } #endregion #region IndexOfMinMagnitude @@ -1340,6 +1378,18 @@ public void IndexOfMinMagnitude_Negative0LesserThanPositive0() Assert.Equal(1, IndexOfMinMagnitude([ConvertFromSingle(-1), ConvertFromSingle(-0f)])); Assert.Equal(1, IndexOfMinMagnitude([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1f)])); } + + [Fact] + public void IndexOfMinMagnitude_IndexAboveMaxValue() + { + var size = IndexOfSizeExceedingMaxValue(); + if (size == null) return; + + using BoundedMemory x = CreateTensor(size.Value); + x.Span.Fill(ConvertFromSingle(1)); + x.Span[size.Value - 1] = ConvertFromSingle(0); + Assert.Equal(size.Value - 1, IndexOfMinMagnitude(x)); + } #endregion #region Log