diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfOperator.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfOperator.cs index 09492ac64eb075..ede21d91865ba1 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfOperator.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfOperator.cs @@ -53,6 +53,9 @@ private static int IndexOfFinalAggregate(Vector128 resul if (sizeof(T) == 2) { + // For short/ushort types, use unsigned comparison for index ordering + // This allows indices up to 65535 to be compared correctly even when + // stored as signed short (which wraps negative for values > 32767) // Compare 0,1,2,3 with 4,5,6,7 tmpResult = Vector128.Shuffle(result.AsInt16(), Vector128.Create(4, 5, 6, 7, 0, 1, 2, 3)).As(); tmpIndex = Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create(4, 5, 6, 7, 0, 1, 2, 3)).As(); @@ -68,8 +71,8 @@ private static int IndexOfFinalAggregate(Vector128 resul tmpIndex = Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create(1, 0, 2, 3, 4, 5, 6, 7)).As(); TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); - // Return 0 - return resultIndex.As().ToScalar(); + // Return 0 - interpret as unsigned to handle overflow correctly + return (int)(ushort)resultIndex.As().ToScalar(); } Debug.Assert(sizeof(T) == 1); @@ -94,8 +97,8 @@ private static int IndexOfFinalAggregate(Vector128 resul tmpIndex = Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As(); TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); - // Return 0 - return resultIndex.As().ToScalar(); + // Return 0 - explicitly cast to int for consistency + return (int)resultIndex.As().ToScalar(); } } @@ -126,7 +129,7 @@ 
private static int IndexOfFinalAggregate(Vector512 resul private static Vector128 IndexLessThan(Vector128 indices1, Vector128 indices2) => sizeof(T) == sizeof(long) ? Vector128.LessThan(indices1.AsInt64(), indices2.AsInt64()).As() : sizeof(T) == sizeof(int) ? Vector128.LessThan(indices1.AsInt32(), indices2.AsInt32()).As() : - sizeof(T) == sizeof(short) ? Vector128.LessThan(indices1.AsInt16(), indices2.AsInt16()).As() : + sizeof(T) == sizeof(short) ? Vector128.LessThan(indices1.AsUInt16(), indices2.AsUInt16()).As() : Vector128.LessThan(indices1.AsByte(), indices2.AsByte()).As(); } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs index f40f7e1e2e2ba0..598a3d8476261a 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs @@ -149,6 +149,16 @@ private static unsafe int IndexOfMinMaxCore(ReadOnlySpan x { Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8); + // For byte and short types, check if the array is large enough to cause index overflow + // If so, fall back to scalar processing to avoid incorrect results + // byte/sbyte: max index 255, ushort: max index 65535, short: max index 32767 + if ((sizeof(T) == 1 && x.Length >= 256) || + (typeof(T) == typeof(short) && x.Length >= 32768) || + (typeof(T) == typeof(ushort) && x.Length >= 65536)) + { + goto ScalarPath; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] static Vector512 CreateVector512T(int i) => sizeof(T) == sizeof(long) ? 
Vector512.Create((long)i).As() : @@ -233,6 +243,16 @@ static Vector512 CreateVector512T(int i) => { Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8); + // For byte and short types, check if the array is large enough to cause index overflow + // If so, fall back to scalar processing to avoid incorrect results + // byte/sbyte: max index 255, ushort: max index 65535, short: max index 32767 + if ((sizeof(T) == 1 && x.Length >= 256) || + (typeof(T) == typeof(short) && x.Length >= 32768) || + (typeof(T) == typeof(ushort) && x.Length >= 65536)) + { + goto ScalarPath; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] static Vector256 CreateVector256T(int i) => sizeof(T) == sizeof(long) ? Vector256.Create((long)i).As() : @@ -317,6 +337,16 @@ static Vector256 CreateVector256T(int i) => { Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8); + // For byte and short types, check if the array is large enough to cause index overflow + // If so, fall back to scalar processing to avoid incorrect results + // byte/sbyte: max index 255, ushort: max index 65535, short: max index 32767 + if ((sizeof(T) == 1 && x.Length >= 256) || + (typeof(T) == typeof(short) && x.Length >= 32768) || + (typeof(T) == typeof(ushort) && x.Length >= 65536)) + { + goto ScalarPath; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] static Vector128 CreateVector128T(int i) => sizeof(T) == sizeof(long) ? Vector128.Create((long)i).As() : @@ -397,6 +427,7 @@ static Vector128 CreateVector128T(int i) => return IndexOfFinalAggregate(result, resultIndex); } + ScalarPath: // Scalar path used when either vectorization is not supported or the input is too small to vectorize. T curResult = x[0]; int curIn = 0; @@ -432,14 +463,14 @@ private static int IndexOfFirstMatch(Vector512 mask) => private static unsafe Vector256 IndexLessThan(Vector256 indices1, Vector256 indices2) => sizeof(T) == sizeof(long) ? Vector256.LessThan(indices1.AsInt64(), indices2.AsInt64()).As() : sizeof(T) == sizeof(int) ? 
Vector256.LessThan(indices1.AsInt32(), indices2.AsInt32()).As() : - sizeof(T) == sizeof(short) ? Vector256.LessThan(indices1.AsInt16(), indices2.AsInt16()).As() : + sizeof(T) == sizeof(short) ? Vector256.LessThan(indices1.AsUInt16(), indices2.AsUInt16()).As() : Vector256.LessThan(indices1.AsByte(), indices2.AsByte()).As(); [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector512 IndexLessThan(Vector512 indices1, Vector512 indices2) => sizeof(T) == sizeof(long) ? Vector512.LessThan(indices1.AsInt64(), indices2.AsInt64()).As() : sizeof(T) == sizeof(int) ? Vector512.LessThan(indices1.AsInt32(), indices2.AsInt32()).As() : - sizeof(T) == sizeof(short) ? Vector512.LessThan(indices1.AsInt16(), indices2.AsInt16()).As() : + sizeof(T) == sizeof(short) ? Vector512.LessThan(indices1.AsUInt16(), indices2.AsUInt16()).As() : Vector512.LessThan(indices1.AsByte(), indices2.AsByte()).As(); /// Gets whether the specified is negative. diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index b21666434904f1..319d83096339f7 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -1134,6 +1134,55 @@ public void IndexOfMax_Negative0LesserThanPositive0() Assert.Equal(1, IndexOfMax([ConvertFromSingle(-1), ConvertFromSingle(-0f)])); Assert.Equal(2, IndexOfMax([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1f)])); } + +#if !SNT_NET8_TESTS + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.Is64BitProcess))] + public void IndexOfMax_NoIntegerOverflow() + { + // Each branch fills the array with values strictly below the element type's maximum and then stores that maximum at an index too large to be represented in the element type. + // The maximum is therefore unique, so the expected index does not depend on how ties are broken and can only be returned if the index is not truncated to the element width. + if (typeof(T) == typeof(byte)) + { + byte[] data = new byte[258]; + for (int i = 0; i < data.Length; i++) + { + data[i] = (byte)(i % 255); + } + data[257] = 255; + Assert.Equal(257, TensorPrimitives.IndexOfMax(data)); + } + else if (typeof(T) == typeof(sbyte)) + { + sbyte[]
data = new sbyte[258]; + for (int i = 0; i < data.Length; i++) + { + data[i] = (sbyte)((i % 255) - 128); + } + data[257] = 127; + Assert.Equal(257, TensorPrimitives.IndexOfMax(data)); + } + else if (typeof(T) == typeof(short)) + { + short[] data = new short[32770]; + for (int i = 0; i < data.Length; i++) + { + data[i] = (short)(i % 32767); + } + data[32769] = 32767; + Assert.Equal(32769, TensorPrimitives.IndexOfMax(data)); + } + else if (typeof(T) == typeof(ushort)) + { + ushort[] data = new ushort[65538]; + for (int i = 0; i < data.Length; i++) + { + data[i] = (ushort)(i % 65535); + } + data[65537] = 65535; + Assert.Equal(65537, TensorPrimitives.IndexOfMax(data)); + } + } +#endif #endregion #region IndexOfMaxMagnitude