Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/jit/importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3378,6 +3378,7 @@ GenTree* Compiler::impIntrinsic(GenTree* newobjThis,
bool mustExpand = false;
bool isSpecial = false;
CorInfoIntrinsics intrinsicID = CORINFO_INTRINSIC_Illegal;
NamedIntrinsic ni = NI_Illegal;

if ((methodFlags & CORINFO_FLG_INTRINSIC) != 0)
{
Expand All @@ -3388,6 +3389,20 @@ GenTree* Compiler::impIntrinsic(GenTree* newobjThis,
{
// The recursive calls to Jit intrinsics are must-expand by convention.
mustExpand = mustExpand || gtIsRecursiveCall(method);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're now always expanding calls to HW intrinsics then isn't this comment and logic out of date?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be.

I wasn't sure if there are other JIT intrinsics, which can be recursive, but for which we do not want to always expand.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We added this bit just for HW intrinsics.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then should we rename the bit to indicate that it is exclusively for HWIntrinsics (CORINFO_FLG_HW_INTRINSIC) or that this bit will assume "mustExpand = true" (CORINFO_FLG_MUSTEXPAND_INTRINSIC).

Based on the current logic in the VM (https://github.com/dotnet/coreclr/blob/master/src/vm/methodtablebuilder.cpp#L5144) this bit is set for any method marked with [Intrinsic]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, 'bit' was a poor choice of words. We added this bit of logic in the importer just for HW intrinsics.

Not all [Intrinsic] methods are must expand; some of them have perfectly viable IL implementations.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think I get it 😄

Then my question is: Do we expect all intrinsics for which gtIsRecursiveCall() would return true to always expand (except for indirect invocation) or do we only expect them to always expand for hardware intrinsics?

Even in the case of the former, I'm not sure how we detect the difference between a JitIntrinsic that is recursive and one that has an IL implementation on first pass..


if (intrinsicID == CORINFO_INTRINSIC_Illegal)
{
ni = lookupNamedIntrinsic(method);

#if FEATURE_HW_INTRINSICS
#ifdef _TARGET_XARCH_
if (ni > NI_HW_INTRINSIC_START && ni < NI_HW_INTRINSIC_END)
{
return impX86HWIntrinsic(ni, method, sig);
}
#endif // _TARGET_XARCH_
#endif // FEATURE_HW_INTRINSICS
}
}

*pIntrinsicID = intrinsicID;
Expand Down Expand Up @@ -3875,16 +3890,9 @@ GenTree* Compiler::impIntrinsic(GenTree* newobjThis,
}

// Look for new-style jit intrinsics by name
if ((intrinsicID == CORINFO_INTRINSIC_Illegal) && ((methodFlags & CORINFO_FLG_JIT_INTRINSIC) != 0))
if (ni != NI_Illegal)
{
assert(retNode == nullptr);
const NamedIntrinsic ni = lookupNamedIntrinsic(method);
#if FEATURE_HW_INTRINSICS && defined(_TARGET_XARCH_)
if (ni > NI_HW_INTRINSIC_START && ni < NI_HW_INTRINSIC_END)
{
return impX86HWIntrinsic(ni, method, sig);
}
#endif // FEATURE_HW_INTRINSICS
switch (ni)
{
case NI_System_Enum_HasFlag:
Expand Down
36 changes: 32 additions & 4 deletions tests/src/JIT/HardwareIntrinsics/X86/Avx/Add.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,53 @@ static unsafe int Main(string[] args)
var vf3 = Avx.Add(vf1, vf2);
Unsafe.Write(floatTable.outArrayPtr, vf3);

if (!floatTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX Add failed on float:");
foreach (var item in floatTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

vf3 = (Vector256<float>)typeof(Avx).GetMethod(nameof(Avx.Add), new Type[] { vf1.GetType(), vf2.GetType() }).Invoke(null, new object[] { vf1, vf2 });
Unsafe.Write(floatTable.outArrayPtr, vf3);

if (!floatTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX Add failed via reflection on float:");
foreach (var item in floatTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

var vd1 = Unsafe.Read<Vector256<double>>(doubleTable.inArray1Ptr);
var vd2 = Unsafe.Read<Vector256<double>>(doubleTable.inArray2Ptr);
var vd3 = Avx.Add(vd1, vd2);
Unsafe.Write(doubleTable.outArrayPtr, vd3);

if (!floatTable.CheckResult((x, y, z) => x + y == z))
if (!doubleTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX Add failed on float:");
foreach (var item in floatTable.outArray)
Console.WriteLine("AVX Add failed on double:");
foreach (var item in doubleTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

vd3 = (Vector256<double>)typeof(Avx).GetMethod(nameof(Avx.Add), new Type[] { vd1.GetType(), vd2.GetType() }).Invoke(null, new object[] { vd1, vd2 });
Unsafe.Write(doubleTable.outArrayPtr, vd3);

if (!doubleTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX Add failed on double:");
Console.WriteLine("AVX Add failed via reflection on double:");
foreach (var item in doubleTable.outArray)
{
Console.Write(item + ", ");
Expand Down
181 changes: 146 additions & 35 deletions tests/src/JIT/HardwareIntrinsics/X86/Avx2/Add.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,44 +37,23 @@ static unsafe int Main(string[] args)
var vi3 = Avx2.Add(vi1, vi2);
Unsafe.Write(intTable.outArrayPtr, vi3);

var vl1 = Unsafe.Read<Vector256<long>>(longTable.inArray1Ptr);
var vl2 = Unsafe.Read<Vector256<long>>(longTable.inArray2Ptr);
var vl3 = Avx2.Add(vl1, vl2);
Unsafe.Write(longTable.outArrayPtr, vl3);

var vui1 = Unsafe.Read<Vector256<uint>>(uintTable.inArray1Ptr);
var vui2 = Unsafe.Read<Vector256<uint>>(uintTable.inArray2Ptr);
var vui3 = Avx2.Add(vui1, vui2);
Unsafe.Write(uintTable.outArrayPtr, vui3);

var vul1 = Unsafe.Read<Vector256<ulong>>(ulongTable.inArray1Ptr);
var vul2 = Unsafe.Read<Vector256<ulong>>(ulongTable.inArray2Ptr);
var vul3 = Avx2.Add(vul1, vul2);
Unsafe.Write(ulongTable.outArrayPtr, vul3);

var vs1 = Unsafe.Read<Vector256<short>>(shortTable.inArray1Ptr);
var vs2 = Unsafe.Read<Vector256<short>>(shortTable.inArray2Ptr);
var vs3 = Avx2.Add(vs1, vs2);
Unsafe.Write(shortTable.outArrayPtr, vs3);

var vus1 = Unsafe.Read<Vector256<ushort>>(ushortTable.inArray1Ptr);
var vus2 = Unsafe.Read<Vector256<ushort>>(ushortTable.inArray2Ptr);
var vus3 = Avx2.Add(vus1, vus2);
Unsafe.Write(ushortTable.outArrayPtr, vus3);

var vsb1 = Unsafe.Read<Vector256<sbyte>>(sbyteTable.inArray1Ptr);
var vsb2 = Unsafe.Read<Vector256<sbyte>>(sbyteTable.inArray2Ptr);
var vsb3 = Avx2.Add(vsb1, vsb2);
Unsafe.Write(sbyteTable.outArrayPtr, vsb3);
if (!intTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed on int:");
foreach (var item in intTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

var vb1 = Unsafe.Read<Vector256<byte>>(byteTable.inArray1Ptr);
var vb2 = Unsafe.Read<Vector256<byte>>(byteTable.inArray2Ptr);
var vb3 = Avx2.Add(vb1, vb2);
Unsafe.Write(byteTable.outArrayPtr, vb3);
vi3 = (Vector256<int>)typeof(Avx2).GetMethod(nameof(Avx2.Add), new Type[] { vi1.GetType(), vi2.GetType() }).Invoke(null, new object[] { vi1, vi2 });
Unsafe.Write(intTable.outArrayPtr, vi3);

if (!intTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed on int:");
Console.WriteLine("AVX2 Add failed via reflection on int:");
foreach (var item in intTable.outArray)
{
Console.Write(item + ", ");
Expand All @@ -83,6 +62,11 @@ static unsafe int Main(string[] args)
testResult = Fail;
}

var vl1 = Unsafe.Read<Vector256<long>>(longTable.inArray1Ptr);
var vl2 = Unsafe.Read<Vector256<long>>(longTable.inArray2Ptr);
var vl3 = Avx2.Add(vl1, vl2);
Unsafe.Write(longTable.outArrayPtr, vl3);

if (!longTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed on long:");
Expand All @@ -94,6 +78,25 @@ static unsafe int Main(string[] args)
testResult = Fail;
}

vl3 = (Vector256<long>)typeof(Avx2).GetMethod(nameof(Avx2.Add), new Type[] { vl1.GetType(), vl2.GetType() }).Invoke(null, new object[] { vl1, vl2 });
Unsafe.Write(longTable.outArrayPtr, vl3);

if (!longTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed via reflection on long:");
foreach (var item in longTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

var vui1 = Unsafe.Read<Vector256<uint>>(uintTable.inArray1Ptr);
var vui2 = Unsafe.Read<Vector256<uint>>(uintTable.inArray2Ptr);
var vui3 = Avx2.Add(vui1, vui2);
Unsafe.Write(uintTable.outArrayPtr, vui3);

if (!uintTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed on uint:");
Expand All @@ -105,6 +108,25 @@ static unsafe int Main(string[] args)
testResult = Fail;
}

vui3 = (Vector256<uint>)typeof(Avx2).GetMethod(nameof(Avx2.Add), new Type[] { vui1.GetType(), vui2.GetType() }).Invoke(null, new object[] { vui1, vui2 });
Unsafe.Write(uintTable.outArrayPtr, vui3);

if (!uintTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed via reflection on uint:");
foreach (var item in uintTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

var vul1 = Unsafe.Read<Vector256<ulong>>(ulongTable.inArray1Ptr);
var vul2 = Unsafe.Read<Vector256<ulong>>(ulongTable.inArray2Ptr);
var vul3 = Avx2.Add(vul1, vul2);
Unsafe.Write(ulongTable.outArrayPtr, vul3);

if (!ulongTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed on ulong:");
Expand All @@ -116,6 +138,25 @@ static unsafe int Main(string[] args)
testResult = Fail;
}

vul3 = (Vector256<ulong>)typeof(Avx2).GetMethod(nameof(Avx2.Add), new Type[] { vul1.GetType(), vul2.GetType() }).Invoke(null, new object[] { vul1, vul2 });
Unsafe.Write(ulongTable.outArrayPtr, vul3);

if (!ulongTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed via reflection on ulong:");
foreach (var item in ulongTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

var vs1 = Unsafe.Read<Vector256<short>>(shortTable.inArray1Ptr);
var vs2 = Unsafe.Read<Vector256<short>>(shortTable.inArray2Ptr);
var vs3 = Avx2.Add(vs1, vs2);
Unsafe.Write(shortTable.outArrayPtr, vs3);

if (!shortTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed on short:");
Expand All @@ -127,6 +168,25 @@ static unsafe int Main(string[] args)
testResult = Fail;
}

vs3 = (Vector256<short>)typeof(Avx2).GetMethod(nameof(Avx2.Add), new Type[] { vs1.GetType(), vs2.GetType() }).Invoke(null, new object[] { vs1, vs2 });
Unsafe.Write(shortTable.outArrayPtr, vs3);

if (!shortTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed via reflection on short:");
foreach (var item in shortTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

var vus1 = Unsafe.Read<Vector256<ushort>>(ushortTable.inArray1Ptr);
var vus2 = Unsafe.Read<Vector256<ushort>>(ushortTable.inArray2Ptr);
var vus3 = Avx2.Add(vus1, vus2);
Unsafe.Write(ushortTable.outArrayPtr, vus3);

if (!ushortTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed on ushort:");
Expand All @@ -138,6 +198,25 @@ static unsafe int Main(string[] args)
testResult = Fail;
}

vus3 = (Vector256<ushort>)typeof(Avx2).GetMethod(nameof(Avx2.Add), new Type[] { vus1.GetType(), vus2.GetType() }).Invoke(null, new object[] { vus1, vus2 });
Unsafe.Write(ushortTable.outArrayPtr, vus3);

if (!ushortTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed via reflection on ushort:");
foreach (var item in ushortTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

var vsb1 = Unsafe.Read<Vector256<sbyte>>(sbyteTable.inArray1Ptr);
var vsb2 = Unsafe.Read<Vector256<sbyte>>(sbyteTable.inArray2Ptr);
var vsb3 = Avx2.Add(vsb1, vsb2);
Unsafe.Write(sbyteTable.outArrayPtr, vsb3);

if (!sbyteTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed on sbyte:");
Expand All @@ -149,6 +228,25 @@ static unsafe int Main(string[] args)
testResult = Fail;
}

vsb3 = (Vector256<sbyte>)typeof(Avx2).GetMethod(nameof(Avx2.Add), new Type[] { vsb1.GetType(), vsb2.GetType() }).Invoke(null, new object[] { vsb1, vsb2 });
Unsafe.Write(sbyteTable.outArrayPtr, vsb3);

if (!sbyteTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed via reflection on sbyte:");
foreach (var item in sbyteTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}

var vb1 = Unsafe.Read<Vector256<byte>>(byteTable.inArray1Ptr);
var vb2 = Unsafe.Read<Vector256<byte>>(byteTable.inArray2Ptr);
var vb3 = Avx2.Add(vb1, vb2);
Unsafe.Write(byteTable.outArrayPtr, vb3);

if (!byteTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed on byte:");
Expand All @@ -159,8 +257,21 @@ static unsafe int Main(string[] args)
Console.WriteLine();
testResult = Fail;
}
}

vb3 = (Vector256<byte>)typeof(Avx2).GetMethod(nameof(Avx2.Add), new Type[] { vb1.GetType(), vb2.GetType() }).Invoke(null, new object[] { vb1, vb2 });
Unsafe.Write(byteTable.outArrayPtr, vb3);

if (!byteTable.CheckResult((x, y, z) => x + y == z))
{
Console.WriteLine("AVX2 Add failed via reflection on byte:");
foreach (var item in byteTable.outArray)
{
Console.Write(item + ", ");
}
Console.WriteLine();
testResult = Fail;
}
}
}

return testResult;
Expand Down
Loading