diff --git a/src/jit/codegen.h b/src/jit/codegen.h index c6e38ab6af60..090283ee50e8 100755 --- a/src/jit/codegen.h +++ b/src/jit/codegen.h @@ -390,6 +390,8 @@ class CodeGen : public CodeGenInterface // Save/Restore callee saved float regs to stack void genPreserveCalleeSavedFltRegs(unsigned lclFrameSize); void genRestoreCalleeSavedFltRegs(unsigned lclFrameSize); + // Generate VZeroupper instruction to avoid AVX/SSE transition penalty + void genVzeroupperIfNeeded(bool check256bitOnly = true); #endif // _TARGET_XARCH_ && FEATURE_STACK_FP_X87 diff --git a/src/jit/codegencommon.cpp b/src/jit/codegencommon.cpp index 240911523f4d..f42103ebce71 100644 --- a/src/jit/codegencommon.cpp +++ b/src/jit/codegencommon.cpp @@ -10583,6 +10583,7 @@ GenTreePtr CodeGen::genMakeConst(const void* cnsAddr, var_types cnsType, GenTree // funclet frames: this will be FuncletInfo.fiSpDelta. void CodeGen::genPreserveCalleeSavedFltRegs(unsigned lclFrameSize) { + genVzeroupperIfNeeded(false); regMaskTP regMask = compiler->compCalleeFPRegsSavedMask; // Only callee saved floating point registers should be in regMask @@ -10621,16 +10622,6 @@ void CodeGen::genPreserveCalleeSavedFltRegs(unsigned lclFrameSize) offset -= XMM_REGSIZE_BYTES; } } - -#ifdef FEATURE_AVX_SUPPORT - // Just before restoring float registers issue a Vzeroupper to zero out upper 128-bits of all YMM regs. - // This is to avoid penalty if this routine is using AVX-256 and now returning to a routine that is - // using SSE2. 
- if (compiler->getFloatingPointInstructionSet() == InstructionSet_AVX) - { - instGen(INS_vzeroupper); - } -#endif } // Save/Restore compCalleeFPRegsPushed with the smallest register number saved at [RSP+offset], working @@ -10651,6 +10642,7 @@ void CodeGen::genRestoreCalleeSavedFltRegs(unsigned lclFrameSize) // fast path return if (regMask == RBM_NONE) { + genVzeroupperIfNeeded(); return; } @@ -10682,16 +10674,6 @@ void CodeGen::genRestoreCalleeSavedFltRegs(unsigned lclFrameSize) assert((offset % 16) == 0); #endif // _TARGET_AMD64_ -#ifdef FEATURE_AVX_SUPPORT - // Just before restoring float registers issue a Vzeroupper to zero out upper 128-bits of all YMM regs. - // This is to avoid penalty if this routine is using AVX-256 and now returning to a routine that is - // using SSE2. - if (compiler->getFloatingPointInstructionSet() == InstructionSet_AVX) - { - instGen(INS_vzeroupper); - } -#endif - for (regNumber reg = REG_FLT_CALLEE_SAVED_FIRST; regMask != RBM_NONE; reg = REG_NEXT(reg)) { regMaskTP regBit = genRegMask(reg); @@ -10706,7 +10688,41 @@ void CodeGen::genRestoreCalleeSavedFltRegs(unsigned lclFrameSize) offset -= XMM_REGSIZE_BYTES; } } + genVzeroupperIfNeeded(); +} + +// Generate Vzeroupper instruction as needed to zero out the upper 128 bits of all YMM registers so that the +// AVX/Legacy SSE transition penalties can be avoided. This function is used in genPreserveCalleeSavedFltRegs +// (prolog) and genRestoreCalleeSavedFltRegs (epilog). Issue VZEROUPPER in Prolog if the method contains +// 128-bit or 256-bit AVX code, to avoid legacy SSE to AVX transition penalty, which could happen when native +// code contains legacy SSE code calling into JIT AVX code (e.g. reverse pinvoke). Issue VZEROUPPER in Epilog +// if the method contains 256-bit AVX code, to avoid AVX to legacy SSE transition penalty. 
+// +// Params +// check256bitOnly - true to check if the function contains 256-bit AVX instruction and generate Vzeroupper +// instruction, false to check if the function contains AVX instruction (either 128-bit or 256-bit). +// +void CodeGen::genVzeroupperIfNeeded(bool check256bitOnly /* = true*/) +{ +#ifdef FEATURE_AVX_SUPPORT + bool emitVzeroUpper = false; + if (check256bitOnly) + { + emitVzeroUpper = getEmitter()->Contains256bitAVX(); + } + else + { + emitVzeroUpper = getEmitter()->ContainsAVX(); + } + + if (emitVzeroUpper) + { + assert(compiler->getSIMDInstructionSet() == InstructionSet_AVX); + instGen(INS_vzeroupper); + } +#endif } + #endif // defined(_TARGET_XARCH_) && !FEATURE_STACK_FP_X87 //----------------------------------------------------------------------------------- diff --git a/src/jit/codegenxarch.cpp b/src/jit/codegenxarch.cpp index 8e0af48799ab..3241c8833872 100644 --- a/src/jit/codegenxarch.cpp +++ b/src/jit/codegenxarch.cpp @@ -5001,6 +5001,20 @@ void CodeGen::genCallInstruction(GenTreePtr node) #endif // defined(_TARGET_X86_) +#ifdef FEATURE_AVX_SUPPORT + // When it's a PInvoke call and the call type is USER function, we issue VZEROUPPER here + // if the function contains 256-bit AVX instructions, this is to avoid AVX-256 to Legacy SSE + // transition penalty, assuming the user function contains legacy SSE instruction. + // To limit code size increase impact: we only issue VZEROUPPER before PInvoke call, not issue + // VZEROUPPER after PInvoke call because transition penalty from legacy SSE to AVX only happens + // when there's preceding 256-bit AVX to legacy SSE transition penalty. 
+ if (call->IsPInvoke() && (call->gtCallType == CT_USER_FUNC) && getEmitter()->Contains256bitAVX()) + { + assert(compiler->getSIMDInstructionSet() == InstructionSet_AVX); + instGen(INS_vzeroupper); + } +#endif + if (target != nullptr) { #ifdef _TARGET_X86_ diff --git a/src/jit/compiler.cpp b/src/jit/compiler.cpp index 30eccc3ce742..47d3c352c44a 100644 --- a/src/jit/compiler.cpp +++ b/src/jit/compiler.cpp @@ -2310,6 +2310,9 @@ void Compiler::compSetProcessor() if (opts.compCanUseAVX) { codeGen->getEmitter()->SetUseAVX(true); + // Assume each JITted method does not contain AVX instruction at first + codeGen->getEmitter()->SetContainsAVX(false); + codeGen->getEmitter()->SetContains256bitAVX(false); } else #endif // FEATURE_AVX_SUPPORT diff --git a/src/jit/emitxarch.h b/src/jit/emitxarch.h index 98256cdaa707..40f22ed52627 100644 --- a/src/jit/emitxarch.h +++ b/src/jit/emitxarch.h @@ -150,6 +150,26 @@ void SetUseAVX(bool value) useAVXEncodings = value; } +bool containsAVXInstruction = false; +bool ContainsAVX() +{ + return containsAVXInstruction; +} +void SetContainsAVX(bool value) +{ + containsAVXInstruction = value; +} + +bool contains256bitAVXInstruction = false; +bool Contains256bitAVX() +{ + return contains256bitAVXInstruction; +} +void SetContains256bitAVX(bool value) +{ + contains256bitAVXInstruction = value; +} + bool IsThreeOperandBinaryAVXInstruction(instruction ins); bool IsThreeOperandMoveAVXInstruction(instruction ins); bool IsThreeOperandAVXInstruction(instruction ins) @@ -162,6 +182,14 @@ bool UseAVX() { return false; } +bool ContainsAVX() +{ + return false; +} +bool Contains256bitAVX() +{ + return false; +} bool hasVexPrefix(code_t code) { return false; diff --git a/src/jit/lower.h b/src/jit/lower.h index 555b9e26c666..eecc6606ca89 100644 --- a/src/jit/lower.h +++ b/src/jit/lower.h @@ -235,6 +235,7 @@ class Lowering : public Phase #if defined(_TARGET_XARCH_) void SetMulOpCounts(GenTreePtr tree); + void SetContainsAVXFlags(bool isFloatingPointType = true, 
unsigned sizeOfSIMDVector = 0); #endif // defined(_TARGET_XARCH_) #if !CPU_LOAD_STORE_ARCH diff --git a/src/jit/lowerxarch.cpp b/src/jit/lowerxarch.cpp index bf5d29c596fa..3381060f09af 100644 --- a/src/jit/lowerxarch.cpp +++ b/src/jit/lowerxarch.cpp @@ -166,7 +166,8 @@ void Lowering::TreeNodeInfoInit(GenTree* tree) Compiler* compiler = comp; TreeNodeInfo* info = &(tree->gtLsraInfo); - + // floating type generates AVX instruction (vmovss etc.), set the flag + SetContainsAVXFlags(varTypeIsFloating(tree->TypeGet())); switch (tree->OperGet()) { GenTree* op1; @@ -1773,6 +1774,8 @@ void Lowering::TreeNodeInfoInitBlockStore(GenTreeBlk* blkNode) { MakeSrcContained(blkNode, source); } + // use XMM register to fill with constants, it's AVX instruction and set the flag + SetContainsAVXFlags(); } blkNode->gtBlkOpKind = GenTreeBlk::BlkOpKindUnroll; @@ -1954,6 +1957,9 @@ void Lowering::TreeNodeInfoInitBlockStore(GenTreeBlk* blkNode) // series of 16-byte loads and stores. blkNode->gtLsraInfo.internalFloatCount = 1; blkNode->gtLsraInfo.addInternalCandidates(l, l->internalFloatRegCandidates()); + // Uses XMM reg for load and store and hence check to see whether AVX instructions + // are used for codegen, set ContainsAVX flag + SetContainsAVXFlags(); } // If src or dst are on stack, we don't have to generate the address into a register @@ -2732,6 +2738,7 @@ void Lowering::TreeNodeInfoInitSIMD(GenTree* tree) TreeNodeInfo* info = &(tree->gtLsraInfo); LinearScan* lsra = m_lsra; info->dstCount = 1; + SetContainsAVXFlags(true, simdTree->gtSIMDSize); switch (simdTree->gtSIMDIntrinsicID) { GenTree* op1; @@ -4572,6 +4579,31 @@ void Lowering::SetMulOpCounts(GenTreePtr tree) } } +//------------------------------------------------------------------------------ +// SetContainsAVXFlags: Set ContainsAVX flag when it is floating type, set +// Contains256bitAVX flag when SIMD vector size is 32 bytes +// +// Arguments: +// isFloatingPointType - true if it is floating point type +// sizeOfSIMDVector 
- SIMD Vector size +// +void Lowering::SetContainsAVXFlags(bool isFloatingPointType /* = true */, unsigned sizeOfSIMDVector /* = 0*/) +{ +#ifdef FEATURE_AVX_SUPPORT + if (isFloatingPointType) + { + if (comp->getFloatingPointInstructionSet() == InstructionSet_AVX) + { + comp->getEmitter()->SetContainsAVX(true); + } + if (sizeOfSIMDVector == 32 && comp->getSIMDInstructionSet() == InstructionSet_AVX) + { + comp->getEmitter()->SetContains256bitAVX(true); + } + } +#endif +} + //------------------------------------------------------------------------------ // isRMWRegOper: Can this binary tree node be used in a Read-Modify-Write format //