diff --git a/src/coreclr/jit/codegenxarch.cpp b/src/coreclr/jit/codegenxarch.cpp index f495acf2c07e67..837937d12db79d 100644 --- a/src/coreclr/jit/codegenxarch.cpp +++ b/src/coreclr/jit/codegenxarch.cpp @@ -4403,6 +4403,34 @@ void CodeGen::genCodeForShift(GenTree* tree) inst_RV_SH(ins, size, tree->GetRegNum(), shiftByValue); } } +#if defined(TARGET_64BIT) + else if (tree->OperIsShift() && compiler->compOpportunisticallyDependsOn(InstructionSet_BMI2)) + { + // Try to emit shlx, sarx, shrx if BMI2 is available instead of mov+shl, mov+sar, mov+shr. + switch (tree->OperGet()) + { + case GT_LSH: + ins = INS_shlx; + break; + + case GT_RSH: + ins = INS_sarx; + break; + + case GT_RSZ: + ins = INS_shrx; + break; + + default: + unreached(); + } + + regNumber shiftByReg = shiftBy->GetRegNum(); + emitAttr size = emitTypeSize(tree); + // The order of operandReg and shiftByReg are swapped to follow shlx, sarx and shrx encoding spec. + GetEmitter()->emitIns_R_R_R(ins, size, tree->GetRegNum(), shiftByReg, operandReg); + } +#endif else { // We must have the number of bits to shift stored in ECX, since we constrained this node to diff --git a/src/coreclr/jit/emitxarch.cpp b/src/coreclr/jit/emitxarch.cpp index ebe6030a8d7093..7ca24ecf672675 100644 --- a/src/coreclr/jit/emitxarch.cpp +++ b/src/coreclr/jit/emitxarch.cpp @@ -749,6 +749,9 @@ bool emitter::TakesRexWPrefix(instruction ins, emitAttr attr) case INS_pdep: case INS_pext: case INS_rorx: + case INS_shlx: + case INS_sarx: + case INS_shrx: return true; default: return false; @@ -987,17 +990,32 @@ unsigned emitter::emitOutputRexOrVexPrefixIfNeeded(instruction ins, BYTE* dst, c case INS_rorx: case INS_pdep: case INS_mulx: +// TODO: Unblock when enabled for x86 +#ifdef TARGET_AMD64 + case INS_shrx: +#endif { vexPrefix |= 0x03; break; } case INS_pext: +// TODO: Unblock when enabled for x86 +#ifdef TARGET_AMD64 + case INS_sarx: +#endif { vexPrefix |= 0x02; break; } - +// TODO: Unblock when enabled for x86 +#ifdef TARGET_AMD64 + case INS_shlx: + { + vexPrefix |= 0x01; + break; + } +#endif default: { vexPrefix |= 0x00; @@ -1484,6 +1502,11 @@ bool emitter::emitInsCanOnlyWriteSSE2OrAVXReg(instrDesc* id) case INS_pextrw: case INS_pextrw_sse41: case INS_rorx: +#ifdef TARGET_AMD64 + case INS_shlx: + case INS_sarx: + case INS_shrx: +#endif { // These SSE instructions write to a general purpose integer register. return false; @@ -9519,9 +9542,13 @@ void emitter::emitDispIns( assert(IsThreeOperandAVXInstruction(ins)); regNumber reg2 = id->idReg2(); regNumber reg3 = id->idReg3(); - if (ins == INS_bextr || ins == INS_bzhi) + if (ins == INS_bextr || ins == INS_bzhi +#ifdef TARGET_AMD64 + || ins == INS_shrx || ins == INS_shlx || ins == INS_sarx +#endif + ) { - // BMI bextr and bzhi encodes the reg2 in VEX.vvvv and reg3 in modRM, + // BMI bextr,bzhi, shrx, shlx and sarx encode the reg2 in VEX.vvvv and reg3 in modRM, // which is different from most of other instructions regNumber tmp = reg2; reg2 = reg3; @@ -16323,6 +16350,16 @@ emitter::insExecutionCharacteristics emitter::getInsExecutionCharacteristics(ins break; } +#ifdef TARGET_AMD64 + case INS_shlx: + case INS_sarx: + case INS_shrx: + { + result.insLatency += PERFSCORE_LATENCY_1C; + result.insThroughput = PERFSCORE_THROUGHPUT_2X; + break; + } +#endif default: // unhandled instruction insFmt combination perfScoreUnhandledInstruction(id, &result); diff --git a/src/coreclr/jit/instrsxarch.h b/src/coreclr/jit/instrsxarch.h index 5a626ce26e96ec..0a8527a6393fe5 100644 --- a/src/coreclr/jit/instrsxarch.h +++ b/src/coreclr/jit/instrsxarch.h @@ -605,6 +605,11 @@ INST3(pdep, "pdep", IUM_WR, BAD_CODE, BAD_CODE, INST3(pext, "pext", IUM_WR, BAD_CODE, BAD_CODE, SSE38(0xF5), INS_Flags_IsDstDstSrcAVXInstruction) // Parallel Bits Extract INST3(bzhi, "bzhi", IUM_WR, BAD_CODE, BAD_CODE, SSE38(0xF5), Resets_OF | Writes_SF | Writes_ZF | Undefined_AF | Undefined_PF | Writes_CF | INS_Flags_IsDstDstSrcAVXInstruction) // Zero High Bits Starting with Specified Bit Position INST3(mulx, "mulx", IUM_WR, BAD_CODE, BAD_CODE, SSE38(0xF6), INS_Flags_IsDstDstSrcAVXInstruction) // Unsigned Multiply Without Affecting Flags +#ifdef TARGET_AMD64 +INST3(shlx, "shlx", IUM_WR, BAD_CODE, BAD_CODE, SSE38(0xF7), INS_Flags_IsDstDstSrcAVXInstruction) // Shift Logical Left Without Affecting Flags +INST3(sarx, "sarx", IUM_WR, BAD_CODE, BAD_CODE, PACK4(0xF3, 0x0F, 0x38, 0xF7), INS_Flags_IsDstDstSrcAVXInstruction) // Shift Arithmetic Right Without Affecting Flags +INST3(shrx, "shrx", IUM_WR, BAD_CODE, BAD_CODE, PACK4(0xF2, 0x0F, 0x38, 0xF7), INS_Flags_IsDstDstSrcAVXInstruction) // Shift Logical Right Without Affecting Flags +#endif INST3(LAST_BMI_INSTRUCTION, "LAST_BMI_INSTRUCTION", IUM_WR, BAD_CODE, BAD_CODE, BAD_CODE, INS_FLAGS_None) diff --git a/src/coreclr/jit/lowerxarch.cpp b/src/coreclr/jit/lowerxarch.cpp index ee94c6d9042ddf..65124bfc35533a 100644 --- a/src/coreclr/jit/lowerxarch.cpp +++ b/src/coreclr/jit/lowerxarch.cpp @@ -4850,7 +4850,7 @@ void Lowering::ContainCheckShiftRotate(GenTreeOp* node) assert(source->OperGet() == GT_LONG); MakeSrcContained(node, source); } -#endif // !TARGET_X86 +#endif GenTree* shiftBy = node->gtOp2; if (IsContainableImmed(node, shiftBy) && (shiftBy->AsIntConCommon()->IconValue() <= 255) && diff --git a/src/coreclr/jit/lsraxarch.cpp b/src/coreclr/jit/lsraxarch.cpp index 4f7cde2cc62f65..1e5f03463c8079 100644 --- a/src/coreclr/jit/lsraxarch.cpp +++ b/src/coreclr/jit/lsraxarch.cpp @@ -925,6 +925,18 @@ int LinearScan::BuildShiftRotate(GenTree* tree) { assert(shiftBy->OperIsConst()); } +#if defined(TARGET_64BIT) + else if (tree->OperIsShift() && !tree->isContained() && + compiler->compOpportunisticallyDependsOn(InstructionSet_BMI2)) + { + // shlx (as opposed to mov+shl) instructions handles all register forms, but it does not handle contained form + // for memory operand. Likewise for sarx and shrx. + srcCount += BuildOperandUses(source, srcCandidates); + srcCount += BuildOperandUses(shiftBy, srcCandidates); + BuildDef(tree, dstCandidates); + return srcCount; + } +#endif else { srcCandidates = allRegs(TYP_INT) & ~RBM_RCX; diff --git a/src/tests/JIT/SIMD/ShiftOperations.cs b/src/tests/JIT/SIMD/ShiftOperations.cs new file mode 100644 index 00000000000000..5122d261bdedac --- /dev/null +++ b/src/tests/JIT/SIMD/ShiftOperations.cs @@ -0,0 +1,314 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Runtime.CompilerServices; +using System.Numerics; +using System.Runtime.InteropServices; +using System.Collections.Generic; + +public class Test +{ + [MethodImpl(MethodImplOptions.NoInlining)] + private static R Shlx64bit(T x, int y) + { + switch (x) + { + case ulong a: + ulong resUlong = ((ulong)a) << y; + return (R)Convert.ChangeType(resUlong, typeof(R)); + case uint b: + uint resUint = ((uint)b) << y; + return (R)Convert.ChangeType(resUint, typeof(R)); + case ushort c: + int resInt = ((ushort)c) << y; + return (R)Convert.ChangeType(resInt, typeof(R)); + default: + Console.WriteLine("Unsupported type."); + return default(R); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static R Sarx64bit(T x, int y) + { + int resInt = 0; + switch (x) + { + case long a: + long resLong = ((long)a) >> y; + return (R)Convert.ChangeType(resLong, typeof(R)); + case int b: + resInt = ((int)b) >> y; + return (R)Convert.ChangeType(resInt, typeof(R)); + case short c: + Console.WriteLine($"Before: {Convert.ToString((short)c, toBase: 2)}"); + resInt = ((short)c) >> y; + Console.WriteLine($"After: {Convert.ToString(resInt, toBase: 2)}"); + return (R)Convert.ChangeType(resInt, typeof(R)); + default: + Console.WriteLine("Unsupported type."); + return default(R); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static R Shrx64bit(T x, int y) + { + switch (x) + { + case ulong a: + ulong resUlong = ((ulong)a) >> y; + return (R)Convert.ChangeType(resUlong, typeof(R)); + case uint b: + uint resUint = ((uint)b) >> y; + return (R)Convert.ChangeType(resUint, typeof(R)); + case ushort c: + int resInt = ((ushort)c) >> y; + return (R)Convert.ChangeType(resInt, typeof(R)); + default: + Console.WriteLine("Unsupported type."); + return default(R); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static unsafe ulong ShrxRef64bit(ulong* x, int y) => *x >> y; + + [MethodImpl(MethodImplOptions.NoInlining)] + private static unsafe uint ShrxRef64bit(uint* x, int y) => *x >> y; + + [MethodImpl(MethodImplOptions.NoInlining)] + private static unsafe int ShrxRef64bit(ushort* x, int y) => *x >> y; + + public static unsafe int Main() + { + const int PASS = 100; + const int FAIL = 101; + int returnCode = PASS; + + try + { + // + // Shlx64bit tests + // + + // ulong + int MOD64 = 64; + + Console.WriteLine(); + Console.WriteLine("### UnitTest: Shlx64bit (ulong) ###############"); + ulong[] valUlong = new ulong[] { 0, 8, 1, 1, 0xFFFFFFFFFFFFFFFF }; + int[] shiftBy = new int[] { 1, 1, 63, 65, 1 }; + for (int idx = 0; idx < valUlong.Length; idx++) + { + ulong resULong = (ulong)Shlx64bit(valUlong[idx], shiftBy[idx]); + ulong expectedUlong = (ulong)(valUlong[idx] << (shiftBy[idx] % MOD64)); + if (!Validate(valUlong[idx], shiftBy[idx], resULong, expectedUlong)) + { + returnCode = FAIL; + } + } + + // uint + int MOD32 = 32; + + Console.WriteLine(); + Console.WriteLine("### UnitTest: Shlx64bit (uint) ###############"); + uint[] valUint = new uint[] { 0, 8, 1, 1, 0xFFFFFFFF }; + shiftBy = new int[] { 1, 1, 32, 33, 1 }; + for (int idx = 0; idx < valUint.Length; idx++) + { + uint resUint = (uint)Shlx64bit(valUint[idx], shiftBy[idx]); + uint expectedUint = (uint)(valUint[idx] << (shiftBy[idx] % MOD32)); + if (!Validate(valUint[idx], shiftBy[idx], resUint, expectedUint)) + { + returnCode = FAIL; + } + } + + // ushort + Console.WriteLine(); + Console.WriteLine("### UnitTest: Shlx64bit (ushort) ###############"); + ushort[] valUshort = new ushort[] { 0, 8, 1, 1, 0b_0111_0001_1000_0010 }; + shiftBy = new int[] { 1, 1, 16, 18, 16 }; + for (int idx = 0; idx < valUshort.Length; idx++) + { + int resInt = (int)Shlx64bit(valUshort[idx], shiftBy[idx]); + int expectedInt = (int)(((int)valUshort[idx]) << (shiftBy[idx] % MOD32)); + if (!Validate(valUshort[idx], shiftBy[idx], resInt, expectedInt)) + { + returnCode = FAIL; + } + } + + // + // Sarx64bit tests + // + + // long + Console.WriteLine(); + Console.WriteLine("### UnitTest: Sarx64bit (long) ###############"); + long[] valLong = new long[] { 1, -8, -8, 0x7FFFFFFFFFFFFFFF }; + shiftBy = new int[] { 1, 1, 65, 63 }; + for (int idx = 0; idx < valLong.Length; idx++) + { + long resLong = (long)Sarx64bit(valLong[idx], shiftBy[idx]); + long expectedLong = (long)(valLong[idx] >> (shiftBy[idx] % MOD64)); + if (!Validate(valLong[idx], shiftBy[idx], resLong, expectedLong)) + { + returnCode = FAIL; + } + } + + // int + Console.WriteLine(); + Console.WriteLine("### UnitTest: Sarx64bit (int) ###############"); + int[] valInt = new int[] { 1, -8, -8, 0x7FFFFFFF }; + shiftBy = new int[] { 1, 1, 32, 33 }; + for (int idx = 0; idx < valInt.Length; idx++) + { + int resInt = (int)Sarx64bit(valInt[idx], shiftBy[idx]); + int expectedInt = (int)(valInt[idx] >> (shiftBy[idx] % MOD32)); + if (!Validate(valInt[idx], shiftBy[idx], resInt, expectedInt)) + { + returnCode = FAIL; + } + } + + // short + Console.WriteLine(); + Console.WriteLine("### UnitTest: Sarx64bit (short) ###############"); + short[] valShort = new short[] { 1, -8, -8, 0b_0111_0001_1000_0010 }; + shiftBy = new int[] { 1, 1, 16, 18 }; + for (int idx = 0; idx < valShort.Length; idx++) + { + int resInt = (int)Sarx64bit(valShort[idx], shiftBy[idx]); + int expectedInt = (int)valShort[idx] >> (shiftBy[idx] % MOD32); + if (!Validate(valShort[idx], shiftBy[idx], resInt, expectedInt)) + { + returnCode = FAIL; + } + } + + // + // Shrx64bit tests + // + + // ulong + Console.WriteLine(); + Console.WriteLine("### UnitTest: Shrx64bit (ulong) ###############"); + valUlong = new ulong[] { 1, 8, 8, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF }; + shiftBy = new int[] { 1, 2, 65, 63, 65 }; + for (int idx = 0; idx < valUlong.Length; idx++) + { + ulong resULong = (ulong)Shrx64bit(valUlong[idx], shiftBy[idx]); + ulong expectedUlong = (ulong)(valUlong[idx] >> (shiftBy[idx] % MOD64)); + if (!Validate(valUlong[idx], shiftBy[idx], resULong, expectedUlong)) + { + returnCode = FAIL; + } + } + + // uint + Console.WriteLine(); + Console.WriteLine("### UnitTest: Shrx64bit (uint) ###############"); + valUint = new uint[] { 1, 8, 8, 0xFFFFFFFF }; + shiftBy = new int[] { 1, 1, 32, 33 }; + for (int idx = 0; idx < valUint.Length; idx++) + { + uint resUint = (uint)Shrx64bit(valUint[idx], shiftBy[idx]); + uint expectedUint = (uint)(valUint[idx] >> (shiftBy[idx] % MOD32)); + if (!Validate(valUint[idx], shiftBy[idx], resUint, expectedUint)) + { + returnCode = FAIL; + } + } + + // ushort + Console.WriteLine(); + Console.WriteLine("### UnitTest: Shrx64bit (ushort) ###############"); + valUshort = new ushort[] { 0, 8, 0b_1000_0000_0000_0000, 0b_1000_0000_0000_0000, 0b_1111_0001_1000_0010 }; + shiftBy = new int[] { 1, 1, 15, 18, 40 }; + for (int idx = 0; idx < valUshort.Length; idx++) + { + int resInt = (int)Shrx64bit(valUshort[idx], shiftBy[idx]); + int expectedInt = (int)(((int)valUshort[idx]) >> (shiftBy[idx] % MOD32)); + if (!Validate(valUshort[idx], shiftBy[idx], resInt, expectedInt)) + { + returnCode = FAIL; + } + } + + // + // ShrxRef64bit + // + + // ulong + Console.WriteLine(); + Console.WriteLine("### UnitTest: ShrxRef64bit (ulong) ###############"); + ulong valUlongRef = 8; + int shiftByRef = 1; + ulong resUlongRef = ShrxRef64bit(&valUlongRef, shiftByRef); + ulong expectedULongRef = (ulong)(valUlongRef >> (shiftByRef % MOD64)); + if (!Validate(valUlongRef, shiftByRef, resUlongRef, expectedULongRef)) + { + returnCode = FAIL; + } + + // uint + Console.WriteLine(); + Console.WriteLine("### UnitTest: ShrxRef64bit (uint) ###############"); + uint valUintRef = 0xFFFFFFFF; + shiftByRef = 1; + uint resUintRef = ShrxRef64bit(&valUintRef, shiftByRef); + uint expectedUintRef = (uint)(valUintRef >> (shiftByRef % MOD32)); + if (!Validate(valUintRef, shiftByRef, resUintRef, expectedUintRef)) + { + returnCode = FAIL; + } + + // ushort + Console.WriteLine(); + Console.WriteLine("### UnitTest: ShrxRef64bit (ushort) ###############"); + ushort valUshortRef = 0xFFFF; + shiftByRef = 15; + int resUshortRef = ShrxRef64bit(&valUshortRef, shiftByRef); + int expectedUshortRef = (int)((uint)valUshortRef >> (shiftByRef % MOD32)); + if (!Validate(valUshortRef, shiftByRef, resUshortRef, expectedUshortRef)) + { + returnCode = FAIL; + } + } + catch (Exception e) + { + Console.WriteLine(e.Message); + return FAIL; + } + + Console.WriteLine(); + if (returnCode == PASS) + { + Console.WriteLine("PASSED."); + } + else + { + Console.WriteLine("FAILED."); + } + return returnCode; + } + + private static bool Validate(TValue value, int shiftBy, TResult actual, TResult expected) + { + Console.Write("(value, shiftBy) ({0},{1}): {2}", value, shiftBy, actual); + if (EqualityComparer.Default.Equals(actual, expected)) + { + Console.WriteLine(" == {0} ==> Passed.", expected); + return true; + } + else + { + Console.WriteLine(" != {0} ==> Failed.", expected); + return false; + } + } +} diff --git a/src/tests/JIT/SIMD/ShiftOperations.csproj b/src/tests/JIT/SIMD/ShiftOperations.csproj new file mode 100644 index 00000000000000..d7141b8f4b1601 --- /dev/null +++ b/src/tests/JIT/SIMD/ShiftOperations.csproj @@ -0,0 +1,13 @@ + + + Exe + true + + + PdbOnly + True + + + + + diff --git a/src/tests/issues.targets b/src/tests/issues.targets index 97b8df0f720d7b..458304bc17832f 100644 --- a/src/tests/issues.targets +++ b/src/tests/issues.targets @@ -1484,7 +1484,10 @@ https://github.com/dotnet/runtime/issues/46174 - + + There is a known undefined behavior with shifts and 0xFFFFFFFF overflows, so skip the test for mono. + + Tests features specific to coreclr