diff --git a/src/coreclr/interpreter/compiler.cpp b/src/coreclr/interpreter/compiler.cpp index c184a263bf1bf5..9c2dcaf0194dd8 100644 --- a/src/coreclr/interpreter/compiler.cpp +++ b/src/coreclr/interpreter/compiler.cpp @@ -1215,8 +1215,13 @@ InterpMethod* InterpCompiler::CreateInterpMethod() pDataItems[i] = m_dataItems.Get(i); bool initLocals = (m_methodInfo->options & CORINFO_OPT_INIT_LOCALS) != 0; + CORJIT_FLAGS corJitFlags; + DWORD jitFlagsSize = m_compHnd->getJitFlags(&corJitFlags, sizeof(corJitFlags)); + assert(jitFlagsSize == sizeof(corJitFlags)); - InterpMethod *pMethod = new InterpMethod(m_methodHnd, m_totalVarsStackSize, pDataItems, initLocals); + bool unmanagedCallersOnly = corJitFlags.IsSet(CORJIT_FLAGS::CORJIT_FLAG_REVERSE_PINVOKE); + + InterpMethod *pMethod = new InterpMethod(m_methodHnd, m_totalVarsStackSize, pDataItems, initLocals, unmanagedCallersOnly); return pMethod; } @@ -3038,9 +3043,13 @@ void InterpCompiler::EmitCall(CORINFO_RESOLVED_TOKEN* pConstrainedToken, bool re CORINFO_CONST_LOOKUP lookup; m_compHnd->getAddressOfPInvokeTarget(callInfo.hMethod, &lookup); m_pLastNewIns->data[1] = GetDataItemIndex(lookup.addr); - m_pLastNewIns->data[2] = lookup.accessType == IAT_PVALUE; if (lookup.accessType == IAT_PPVALUE) NO_WAY("IAT_PPVALUE pinvokes not implemented in interpreter"); + bool suppressGCTransition = false; + m_compHnd->getUnmanagedCallConv(callInfo.hMethod, nullptr, &suppressGCTransition); + m_pLastNewIns->data[2] = + ((lookup.accessType == IAT_PVALUE) ? (int32_t)PInvokeCallFlags::Indirect : 0) | + (suppressGCTransition ? (int32_t)PInvokeCallFlags::SuppressGCTransition : 0); } } break; diff --git a/src/coreclr/interpreter/interpretershared.h b/src/coreclr/interpreter/interpretershared.h index b4bb490237ed6d..d9338543d8ea24 100644 --- a/src/coreclr/interpreter/interpretershared.h +++ b/src/coreclr/interpreter/interpretershared.h @@ -35,8 +35,9 @@ struct InterpMethod // This stub is used for calling the interpreted method from JITted/AOTed code CallStubHeader *pCallStub; bool initLocals; + bool unmanagedCallersOnly; - InterpMethod(CORINFO_METHOD_HANDLE methodHnd, int32_t allocaSize, void** pDataItems, bool initLocals) + InterpMethod(CORINFO_METHOD_HANDLE methodHnd, int32_t allocaSize, void** pDataItems, bool initLocals, bool unmanagedCallersOnly) { #if DEBUG this->self = this; @@ -45,6 +46,7 @@ struct InterpMethod this->allocaSize = allocaSize; this->pDataItems = pDataItems; this->initLocals = initLocals; + this->unmanagedCallersOnly = unmanagedCallersOnly; pCallStub = NULL; } @@ -157,4 +159,11 @@ struct InterpGenericLookup uint16_t offsets[InterpGenericLookup_MaxIndirections]; }; +enum class PInvokeCallFlags : int32_t +{ + None = 0, + Indirect = 1 << 0, // The call target address is indirect + SuppressGCTransition = 1 << 1, // The pinvoke is marked by the SuppressGCTransition attribute +}; + #endif diff --git a/src/coreclr/vm/interpexec.cpp b/src/coreclr/vm/interpexec.cpp index 3419a19710d3b3..d81c2a2d0a0b60 100644 --- a/src/coreclr/vm/interpexec.cpp +++ b/src/coreclr/vm/interpexec.cpp @@ -1855,11 +1855,11 @@ void InterpExecMethod(InterpreterFrame *pInterpreterFrame, InterpMethodContextFr callArgsOffset = ip[2]; methodSlot = ip[3]; int32_t targetAddrSlot = ip[4]; - int32_t indirectFlag = ip[5]; + int32_t flags = ip[5]; ip += 6; targetMethod = (MethodDesc*)pMethod->pDataItems[methodSlot]; - PCODE callTarget = indirectFlag + PCODE callTarget = (flags & (int32_t)PInvokeCallFlags::Indirect) ? *(PCODE *)pMethod->pDataItems[targetAddrSlot] : (PCODE)pMethod->pDataItems[targetAddrSlot]; @@ -1872,7 +1872,7 @@ void InterpExecMethod(InterpreterFrame *pInterpreterFrame, InterpMethodContextFr inlinedCallFrame.Push(); { - GCX_PREEMP(); + GCX_MAYBE_PREEMP(!(flags & (int32_t)PInvokeCallFlags::SuppressGCTransition)); InvokeCompiledMethod(targetMethod, stack + callArgsOffset, stack + returnOffset, callTarget); } diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index f9da780c74c417..3ea82cfe8ae4d2 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -2001,6 +2001,8 @@ static InterpThreadContext* GetInterpThreadContext() return threadContext; } +EXTERN_C void STDCALL ReversePInvokeBadTransition(); + extern "C" void* STDCALL ExecuteInterpretedMethod(TransitionBlock* pTransitionBlock, TADDR byteCodeAddr, void* retBuff) { // Argument registers are in the TransitionBlock @@ -2008,6 +2010,21 @@ extern "C" void* STDCALL ExecuteInterpretedMethod(TransitionBlock* pTransitionBl InterpThreadContext *threadContext = GetInterpThreadContext(); int8_t *sp = threadContext->pStackPointer; + InterpByteCodeStart* pInterpreterCode = dac_cast(byteCodeAddr); + + if (pInterpreterCode->Method->unmanagedCallersOnly) + { + Thread* thread = GetThreadNULLOk(); + if (thread == NULL) + CREATETHREAD_IF_NULL_FAILFAST(thread, W("Failed to setup new thread during reverse P/Invoke")); + + // Verify the current thread isn't in COOP mode. + if (thread->PreemptiveGCDisabled()) + ReversePInvokeBadTransition(); + } + + GCX_MAYBE_COOP(pInterpreterCode->Method->unmanagedCallersOnly); + // This construct ensures that the InterpreterFrame is always stored at a higher address than the // InterpMethodContextFrame. This is important for the stack walking code. struct Frames