Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ private static int IndexOfFinalAggregate<T, TIndexOfOperator>(Vector128<T> resul

if (sizeof(T) == 2)
{
// For short/ushort types, use unsigned comparison for index ordering
// This allows indices up to 65535 to be compared correctly even when
// stored as signed short (which wraps negative for values > 32767)
// Compare 0,1,2,3 with 4,5,6,7
tmpResult = Vector128.Shuffle(result.AsInt16(), Vector128.Create(4, 5, 6, 7, 0, 1, 2, 3)).As<short, T>();
tmpIndex = Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create(4, 5, 6, 7, 0, 1, 2, 3)).As<short, T>();
Expand All @@ -68,8 +71,8 @@ private static int IndexOfFinalAggregate<T, TIndexOfOperator>(Vector128<T> resul
tmpIndex = Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create(1, 0, 2, 3, 4, 5, 6, 7)).As<short, T>();
TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex);

// Return 0
return resultIndex.As<T, short>().ToScalar();
// Return 0 - interpret as unsigned to handle overflow correctly
return (int)(ushort)resultIndex.As<T, short>().ToScalar();
}

Debug.Assert(sizeof(T) == 1);
Expand All @@ -94,8 +97,8 @@ private static int IndexOfFinalAggregate<T, TIndexOfOperator>(Vector128<T> resul
tmpIndex = Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>();
TIndexOfOperator.Invoke(ref result, tmpResult, ref resultIndex, tmpIndex);

// Return 0
return resultIndex.As<T, byte>().ToScalar();
// Return 0 - explicitly cast to int for consistency
return (int)resultIndex.As<T, byte>().ToScalar();
}
}

Expand Down Expand Up @@ -126,7 +129,7 @@ private static int IndexOfFinalAggregate<T, TIndexOfOperator>(Vector512<T> resul
private static Vector128<T> IndexLessThan<T>(Vector128<T> indices1, Vector128<T> indices2) =>
sizeof(T) == sizeof(long) ? Vector128.LessThan(indices1.AsInt64(), indices2.AsInt64()).As<long, T>() :
sizeof(T) == sizeof(int) ? Vector128.LessThan(indices1.AsInt32(), indices2.AsInt32()).As<int, T>() :
sizeof(T) == sizeof(short) ? Vector128.LessThan(indices1.AsInt16(), indices2.AsInt16()).As<short, T>() :
sizeof(T) == sizeof(short) ? Vector128.LessThan(indices1.AsUInt16(), indices2.AsUInt16()).As<ushort, T>() :
Vector128.LessThan(indices1.AsByte(), indices2.AsByte()).As<byte, T>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ private static unsafe int IndexOfMinMaxCore<T, TIndexOfMinMax>(ReadOnlySpan<T> x
{
Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8);

// For byte and short types, check if the array is large enough to cause index overflow
// If so, fall back to scalar processing to avoid incorrect results
// byte/sbyte: max index 255, ushort: max index 65535, short: max index 32767
if ((sizeof(T) == 1 && x.Length >= 256) ||
(typeof(T) == typeof(short) && x.Length >= 32768) ||
(typeof(T) == typeof(ushort) && x.Length >= 65536))
{
goto ScalarPath;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static Vector512<T> CreateVector512T(int i) =>
sizeof(T) == sizeof(long) ? Vector512.Create((long)i).As<long, T>() :
Expand Down Expand Up @@ -233,6 +243,16 @@ static Vector512<T> CreateVector512T(int i) =>
{
Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8);

// For byte and short types, check if the array is large enough to cause index overflow
// If so, fall back to scalar processing to avoid incorrect results
// byte/sbyte: max index 255, ushort: max index 65535, short: max index 32767
if ((sizeof(T) == 1 && x.Length >= 256) ||
(typeof(T) == typeof(short) && x.Length >= 32768) ||
(typeof(T) == typeof(ushort) && x.Length >= 65536))
{
goto ScalarPath;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static Vector256<T> CreateVector256T(int i) =>
sizeof(T) == sizeof(long) ? Vector256.Create((long)i).As<long, T>() :
Expand Down Expand Up @@ -317,6 +337,16 @@ static Vector256<T> CreateVector256T(int i) =>
{
Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8);

// For byte and short types, check if the array is large enough to cause index overflow
// If so, fall back to scalar processing to avoid incorrect results
// byte/sbyte: max index 255, ushort: max index 65535, short: max index 32767
if ((sizeof(T) == 1 && x.Length >= 256) ||
(typeof(T) == typeof(short) && x.Length >= 32768) ||
(typeof(T) == typeof(ushort) && x.Length >= 65536))
{
goto ScalarPath;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static Vector128<T> CreateVector128T(int i) =>
sizeof(T) == sizeof(long) ? Vector128.Create((long)i).As<long, T>() :
Expand Down Expand Up @@ -397,6 +427,7 @@ static Vector128<T> CreateVector128T(int i) =>
return IndexOfFinalAggregate<T, TIndexOfMinMax>(result, resultIndex);
}

ScalarPath:
// Scalar path used when either vectorization is not supported or the input is too small to vectorize.
T curResult = x[0];
int curIn = 0;
Expand Down Expand Up @@ -432,14 +463,14 @@ private static int IndexOfFirstMatch<T>(Vector512<T> mask) =>
private static unsafe Vector256<T> IndexLessThan<T>(Vector256<T> indices1, Vector256<T> indices2) =>
sizeof(T) == sizeof(long) ? Vector256.LessThan(indices1.AsInt64(), indices2.AsInt64()).As<long, T>() :
sizeof(T) == sizeof(int) ? Vector256.LessThan(indices1.AsInt32(), indices2.AsInt32()).As<int, T>() :
sizeof(T) == sizeof(short) ? Vector256.LessThan(indices1.AsInt16(), indices2.AsInt16()).As<short, T>() :
sizeof(T) == sizeof(short) ? Vector256.LessThan(indices1.AsUInt16(), indices2.AsUInt16()).As<ushort, T>() :
Vector256.LessThan(indices1.AsByte(), indices2.AsByte()).As<byte, T>();

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector512<T> IndexLessThan<T>(Vector512<T> indices1, Vector512<T> indices2) =>
sizeof(T) == sizeof(long) ? Vector512.LessThan(indices1.AsInt64(), indices2.AsInt64()).As<long, T>() :
sizeof(T) == sizeof(int) ? Vector512.LessThan(indices1.AsInt32(), indices2.AsInt32()).As<int, T>() :
sizeof(T) == sizeof(short) ? Vector512.LessThan(indices1.AsInt16(), indices2.AsInt16()).As<short, T>() :
sizeof(T) == sizeof(short) ? Vector512.LessThan(indices1.AsUInt16(), indices2.AsUInt16()).As<ushort, T>() :
Vector512.LessThan(indices1.AsByte(), indices2.AsByte()).As<byte, T>();

/// <summary>Gets whether the specified <see cref="float"/> is negative.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,53 @@ public void IndexOfMax_Negative0LesserThanPositive0()
Assert.Equal(1, IndexOfMax([ConvertFromSingle(-1), ConvertFromSingle(-0f)]));
Assert.Equal(2, IndexOfMax([ConvertFromSingle(-1), ConvertFromSingle(-0f), ConvertFromSingle(1f)]));
}

#if !SNT_NET8_TESTS
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.Is64BitProcess))]
public void IndexOfMax_NoIntegerOverflow()
{
if (typeof(T) == typeof(byte))
{
byte[] data = new byte[258];
for (int i = 0; i < data.Length; i++)
{
data[i] = (byte)(i % 256);
}
data[257] = 255;
Assert.Equal(257, TensorPrimitives.IndexOfMax<byte>(data));
}
else if (typeof(T) == typeof(sbyte))
{
sbyte[] data = new sbyte[258];
for (int i = 0; i < data.Length; i++)
{
data[i] = (sbyte)((i % 256) - 128);
}
data[257] = 127;
Assert.Equal(257, TensorPrimitives.IndexOfMax<sbyte>(data));
}
else if (typeof(T) == typeof(short))
{
short[] data = new short[32770];
for (int i = 0; i < data.Length; i++)
{
data[i] = (short)(i % 32768);
}
data[32769] = 32767;
Assert.Equal(32769, TensorPrimitives.IndexOfMax<short>(data));
}
else if (typeof(T) == typeof(ushort))
{
ushort[] data = new ushort[65538];
for (int i = 0; i < data.Length; i++)
{
data[i] = (ushort)(i % 65536);
}
data[65537] = 65535;
Assert.Equal(65537, TensorPrimitives.IndexOfMax<ushort>(data));
}
}
#endif
#endregion

#region IndexOfMaxMagnitude
Expand Down
Loading