diff --git a/src/coreclr/debug/daccess/dacdbiimpl.cpp b/src/coreclr/debug/daccess/dacdbiimpl.cpp index a7e1b1cd50cea2..1891248bb85c10 100644 --- a/src/coreclr/debug/daccess/dacdbiimpl.cpp +++ b/src/coreclr/debug/daccess/dacdbiimpl.cpp @@ -1352,7 +1352,7 @@ HRESULT STDMETHODCALLTYPE DacDbiInterfaceImpl::GetNativeCodeInfo(VMPTR_DomainAss MethodDesc* pMethodDesc = FindLoadedMethodRefOrDef(pModule, functionToken); if (pMethodDesc != NULL && pMethodDesc->IsAsyncThunkMethod()) { - MethodDesc* pAsyncVariant = pMethodDesc->GetAsyncOtherVariantNoCreate(); + MethodDesc* pAsyncVariant = pMethodDesc->GetOrdinaryVariantNoCreate(); if (pAsyncVariant != NULL) { pMethodDesc = pAsyncVariant; diff --git a/src/coreclr/vm/asyncthunks.cpp b/src/coreclr/vm/asyncthunks.cpp index 5d72e30f85ae41..30541cc83b94d0 100644 --- a/src/coreclr/vm/asyncthunks.cpp +++ b/src/coreclr/vm/asyncthunks.cpp @@ -25,7 +25,27 @@ bool MethodDesc::TryGenerateAsyncThunk(DynamicResolver** resolver, COR_ILMETHOD_ return false; } - MethodDesc *pAsyncOtherVariant = this->GetAsyncOtherVariant(); + MethodDesc* pAsyncOtherVariant = nullptr; + if (!IsAsyncMethod()) + { + // a non-async thunk is implemented in terms of the async variant which has user code + pAsyncOtherVariant = this->GetAsyncVariant(); + } + else + { + if (!IsReturnDroppingThunk()) + { + // an async thunk is implemented in terms of non-async variant + pAsyncOtherVariant = this->GetOrdinaryVariant(); + } + else + { + // this is a special void-returning async variant that calls + // the normal async variant and drops the result + pAsyncOtherVariant = this->GetAsyncVariant(); + } + } + _ASSERTE(!IsWrapperStub() && !pAsyncOtherVariant->IsWrapperStub()); MetaSig msig(this); @@ -38,13 +58,20 @@ bool MethodDesc::TryGenerateAsyncThunk(DynamicResolver** resolver, COR_ILMETHOD_ pAsyncOtherVariant, (ILStubLinkerFlags)ILSTUB_LINKER_FLAG_NONE); - if (IsAsyncMethod()) + if (!IsAsyncMethod()) { - EmitAsyncMethodThunk(pAsyncOtherVariant, msig, &sl); + EmitTaskReturningThunk(pAsyncOtherVariant, msig, &sl); } else { - EmitTaskReturningThunk(pAsyncOtherVariant, msig, &sl); + if (IsReturnDroppingThunk()) + { + EmitReturnDroppingThunk(pAsyncOtherVariant, msig, &sl); + } + else + { + EmitAsyncMethodThunk(pAsyncOtherVariant, msig, &sl); + } } NewHolder ilResolver = new ILStubResolver(); @@ -132,60 +159,7 @@ void MethodDesc::EmitTaskReturningThunk(MethodDesc* pAsyncCallVariant, MetaSig& pCode->EmitLDARG(localArg++); } - int token; - _ASSERTE(!pAsyncCallVariant->IsWrapperStub()); - if (pAsyncCallVariant->HasClassOrMethodInstantiation()) - { - // For generic code emit generic signatures. - int typeSigToken = mdTokenNil; - if (pAsyncCallVariant->HasClassInstantiation()) - { - SigBuilder typeSigBuilder; - typeSigBuilder.AppendElementType(ELEMENT_TYPE_GENERICINST); - typeSigBuilder.AppendElementType(ELEMENT_TYPE_INTERNAL); - // TODO: (async) Encoding potentially shared method tables in - // signatures of tokens seems odd, but this hits assert - // with the typical method table. - typeSigBuilder.AppendPointer(pAsyncCallVariant->GetMethodTable()); - DWORD numClassTypeArgs = pAsyncCallVariant->GetNumGenericClassArgs(); - typeSigBuilder.AppendData(numClassTypeArgs); - for (DWORD i = 0; i < numClassTypeArgs; ++i) - { - typeSigBuilder.AppendElementType(ELEMENT_TYPE_VAR); - typeSigBuilder.AppendData(i); - } - - DWORD typeSigLen; - PCCOR_SIGNATURE typeSig = (PCCOR_SIGNATURE)typeSigBuilder.GetSignature(&typeSigLen); - typeSigToken = pCode->GetSigToken(typeSig, typeSigLen); - } - - if (pAsyncCallVariant->HasMethodInstantiation()) - { - SigBuilder methodSigBuilder; - DWORD numMethodTypeArgs = pAsyncCallVariant->GetNumGenericMethodArgs(); - methodSigBuilder.AppendByte(IMAGE_CEE_CS_CALLCONV_GENERICINST); - methodSigBuilder.AppendData(numMethodTypeArgs); - for (DWORD i = 0; i < numMethodTypeArgs; ++i) - { - methodSigBuilder.AppendElementType(ELEMENT_TYPE_MVAR); - methodSigBuilder.AppendData(i); - } - - DWORD sigLen; - PCCOR_SIGNATURE sig = (PCCOR_SIGNATURE)methodSigBuilder.GetSignature(&sigLen); - int methodSigToken = pCode->GetSigToken(sig, sigLen); - token = pCode->GetToken(pAsyncCallVariant, typeSigToken, methodSigToken); - } - else - { - token = pCode->GetToken(pAsyncCallVariant, typeSigToken); - } - } - else - { - token = pCode->GetToken(pAsyncCallVariant); - } + int token = GetTokenForThunkTarget(pCode, pAsyncCallVariant); pCode->EmitCALL(token, localArg, logicalResultLocal != UINT_MAX ? 1 : 0); @@ -429,47 +403,15 @@ int MethodDesc::GetTokenForGenericTypeMethodCallWithAsyncReturnType(ILCodeStream return pCode->GetToken(md, typeSigToken); } -// Provided a Task-returning method, emits an async wrapper. -// The emitted code matches method EmitAsyncMethodThunk in the Managed Type System. -void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig& msig, ILStubLinker* pSL) +int MethodDesc::GetTokenForThunkTarget(ILCodeStream* pCode, MethodDesc* md) { - _ASSERTE(!pTaskReturningVariant->IsAsyncThunkMethod()); - _ASSERTE(!pTaskReturningVariant->IsVoid()); - - // Implement IL that is effectively the following: - // { - // Task task = other(arg); - // if (!task.IsCompleted) - // { - // // Magic function which will suspend the current run of async methods - // AsyncHelpers.TransparentAwait(task); - // } - // return AsyncHelpers.CompletedTaskResult(task); - // } - - // For ValueTask: - - // { - // ValueTask vt = other(arg); - // if (!vt.IsCompleted) - // { - // taskOrNotifier = vt.AsTaskOrNotifier() - - // // Magic function which will suspend the current run of async methods - // AsyncHelpers.TransparentAwait(taskOrNotifier); - // } - - // return vt.Result/vt.ThrowIfCompletedUnsuccessfully(); - // } - ILCodeStream* pCode = pSL->NewCodeStream(ILStubLinker::kDispatch); - - int userFuncToken; - _ASSERTE(!pTaskReturningVariant->IsWrapperStub()); - if (pTaskReturningVariant->HasClassOrMethodInstantiation()) + int token; + _ASSERTE(!md->IsWrapperStub()); + if (md->HasClassOrMethodInstantiation()) { // For generic code emit generic signatures. int typeSigToken = mdTokenNil; - if (pTaskReturningVariant->HasClassInstantiation()) + if (md->HasClassInstantiation()) { SigBuilder typeSigBuilder; typeSigBuilder.AppendElementType(ELEMENT_TYPE_GENERICINST); @@ -477,8 +419,8 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig // TODO: (async) Encoding potentially shared method tables in // signatures of tokens seems odd, but this hits assert // with the typical method table. - typeSigBuilder.AppendPointer(pTaskReturningVariant->GetMethodTable()); - DWORD numClassTypeArgs = pTaskReturningVariant->GetNumGenericClassArgs(); + typeSigBuilder.AppendPointer(md->GetMethodTable()); + DWORD numClassTypeArgs = md->GetNumGenericClassArgs(); typeSigBuilder.AppendData(numClassTypeArgs); for (DWORD i = 0; i < numClassTypeArgs; ++i) { @@ -491,10 +433,10 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig typeSigToken = pCode->GetSigToken(typeSig, typeSigLen); } - if (pTaskReturningVariant->HasMethodInstantiation()) + if (md->HasMethodInstantiation()) { SigBuilder methodSigBuilder; - DWORD numMethodTypeArgs = pTaskReturningVariant->GetNumGenericMethodArgs(); + DWORD numMethodTypeArgs = md->GetNumGenericMethodArgs(); methodSigBuilder.AppendByte(IMAGE_CEE_CS_CALLCONV_GENERICINST); methodSigBuilder.AppendData(numMethodTypeArgs); for (DWORD i = 0; i < numMethodTypeArgs; ++i) @@ -506,18 +448,57 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig DWORD sigLen; PCCOR_SIGNATURE sig = (PCCOR_SIGNATURE)methodSigBuilder.GetSignature(&sigLen); int methodSigToken = pCode->GetSigToken(sig, sigLen); - userFuncToken = pCode->GetToken(pTaskReturningVariant, typeSigToken, methodSigToken); + token = pCode->GetToken(md, typeSigToken, methodSigToken); } else { - userFuncToken = pCode->GetToken(pTaskReturningVariant, typeSigToken); + token = pCode->GetToken(md, typeSigToken); } } else { - userFuncToken = pCode->GetToken(pTaskReturningVariant); + token = pCode->GetToken(md); } + return token; +} + +// Provided a Task-returning method, emits an async wrapper. +// The emitted code matches method EmitAsyncMethodThunk in the Managed Type System. +void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig& msig, ILStubLinker* pSL) +{ + _ASSERTE(!pTaskReturningVariant->IsAsyncThunkMethod()); + _ASSERTE(!pTaskReturningVariant->IsVoid()); + + // Implement IL that is effectively the following: + // { + // Task task = other(arg); + // if (!task.IsCompleted) + // { + // // Magic function which will suspend the current run of async methods + // AsyncHelpers.TransparentAwait(task); + // } + // return AsyncHelpers.CompletedTaskResult(task); + // } + + // For ValueTask: + + // { + // ValueTask vt = other(arg); + // if (!vt.IsCompleted) + // { + // taskOrNotifier = vt.AsTaskOrNotifier() + + // // Magic function which will suspend the current run of async methods + // AsyncHelpers.TransparentAwait(taskOrNotifier); + // } + + // return vt.Result/vt.ThrowIfCompletedUnsuccessfully(); + // } + ILCodeStream* pCode = pSL->NewCodeStream(ILStubLinker::kDispatch); + + int token = GetTokenForThunkTarget(pCode, pTaskReturningVariant); + DWORD localArg = 0; if (msig.HasThis()) { @@ -532,11 +513,11 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig if (pTaskReturningVariant->IsAbstract()) { _ASSERTE(pTaskReturningVariant->IsCLRToCOMCall()); - pCode->EmitCALLVIRT(userFuncToken, localArg, 1); + pCode->EmitCALLVIRT(token, localArg, 1); } else { - pCode->EmitCALL(userFuncToken, localArg, 1); + pCode->EmitCALL(token, localArg, 1); } TypeHandle thLogicalRetType = msig.GetRetTypeHandleThrowing(); @@ -637,3 +618,36 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig pCode->EmitRET(); } } + +// Provided an async variant, emits an async wrapper that drops the returned value. +// Used in the covariant return scenario. +void MethodDesc::EmitReturnDroppingThunk(MethodDesc* pAsyncOtherVariant, MetaSig& msig, ILStubLinker* pSL) +{ + _ASSERTE(pAsyncOtherVariant->IsAsyncVariantMethod()); + + _ASSERTE(!pAsyncOtherVariant->IsVoid()); + _ASSERTE(pAsyncOtherVariant->IsVirtual()); + _ASSERTE(this->IsVoid()); + _ASSERTE(this->IsVirtual()); + + // Implement IL that is effectively the following: + // { + // this.other(arg); // CALLVIRT + // return; + // } + ILCodeStream* pCode = pSL->NewCodeStream(ILStubLinker::kDispatch); + int token = GetTokenForThunkTarget(pCode, pAsyncOtherVariant); + + DWORD localArg = 0; + pCode->EmitLDARG(localArg++); + for (UINT iArg = 0; iArg < msig.NumFixedArgs(); iArg++) + { + pCode->EmitLDARG(localArg++); + } + + // other(arg) + pCode->EmitCALLVIRT(token, localArg, 1); + // return; + pCode->EmitPOP(); + pCode->EmitRET(); +} diff --git a/src/coreclr/vm/class.cpp b/src/coreclr/vm/class.cpp index c532eae900b5b4..dbc9b2514eb188 100644 --- a/src/coreclr/vm/class.cpp +++ b/src/coreclr/vm/class.cpp @@ -1366,12 +1366,6 @@ void ClassLoader::ValidateMethodsWithCovariantReturnTypes(MethodTable* pMT) continue; } MethodDesc* pMD = pMT->GetMethodDescForSlot(i); - - // Skip validation for async variant methods, as they have different signatures by design - // to support the async calling convention - if (pMD->IsAsyncVariantMethod()) - continue; - MethodDesc* pParentMD = pParentMT->GetMethodDescForSlot(i); if (pMD == pParentMD) diff --git a/src/coreclr/vm/genmeth.cpp b/src/coreclr/vm/genmeth.cpp index 6dd9bdda0a1f32..62f413ada24e67 100644 --- a/src/coreclr/vm/genmeth.cpp +++ b/src/coreclr/vm/genmeth.cpp @@ -748,9 +748,9 @@ MethodDesc::FindOrCreateAssociatedMethodDesc(MethodDesc* pDefMD, BOOL forceBoxedEntryPoint, Instantiation methodInst, BOOL allowInstParam, + AsyncVariantLookup asyncVariantLookup, BOOL forceRemotableMethod, BOOL allowCreate, - AsyncVariantLookup asyncVariantLookup, ClassLoadLevel level) { CONTRACT(MethodDesc*) @@ -788,7 +788,7 @@ MethodDesc::FindOrCreateAssociatedMethodDesc(MethodDesc* pDefMD, methodInst.IsEmpty() && !forceBoxedEntryPoint && !pDefMD->IsUnboxingStub() && - asyncVariantLookup == AsyncVariantLookup::MatchingAsyncVariant) + pDefMD->MatchesAsyncVariantLookup(asyncVariantLookup)) { // Make sure that pDefMD->GetMethodTable() and pExactMT are related types even // if we took the fast path. @@ -817,7 +817,9 @@ MethodDesc::FindOrCreateAssociatedMethodDesc(MethodDesc* pDefMD, COMPlusThrowHR(COR_E_TYPELOAD); } - if (pDefMD->HasClassOrMethodInstantiation() || !methodInst.IsEmpty() || asyncVariantLookup == AsyncVariantLookup::AsyncOtherVariant) + if (pDefMD->HasClassOrMethodInstantiation() || + !methodInst.IsEmpty() || + !pDefMD->MatchesAsyncVariantLookup(asyncVariantLookup)) { // General checks related to generics: arity (if any) must match and generic method // instantiation (if any) must be well-formed. @@ -845,8 +847,8 @@ MethodDesc::FindOrCreateAssociatedMethodDesc(MethodDesc* pDefMD, if ( methodInst.IsEmpty() && (allowInstParam || !pMDescInCanonMT->RequiresInstArg()) && (forceBoxedEntryPoint == pMDescInCanonMT->IsUnboxingStub()) - && (!forceRemotableMethod || !pMDescInCanonMT->IsInterface() - || !pMDescInCanonMT->GetMethodTable()->IsSharedByGenericInstantiations()) ) + && (!forceRemotableMethod || !pMDescInCanonMT->IsInterface() || !pMDescInCanonMT->GetMethodTable()->IsSharedByGenericInstantiations()) + && (pMDescInCanonMT->MatchesAsyncVariantLookup(asyncVariantLookup))) { RETURN(pMDescInCanonMT); } @@ -992,7 +994,10 @@ MethodDesc::FindOrCreateAssociatedMethodDesc(MethodDesc* pDefMD, pExactMT, FALSE /* not Unboxing */, methodInst, - FALSE, FALSE, TRUE, asyncVariantLookup); + FALSE, + asyncVariantLookup, + FALSE, + TRUE); _ASSERTE(pNonUnboxingStub->GetClassification() == mcInstantiated); _ASSERTE(!pNonUnboxingStub->RequiresInstArg()); @@ -1200,9 +1205,9 @@ MethodDesc::FindOrCreateAssociatedMethodDesc(MethodDesc* pDefMD, FALSE, Instantiation(repInst, methodInst.GetNumArgs()), /* allowInstParam */ TRUE, + asyncVariantLookup, /* forceRemotableMethod */ FALSE, /* allowCreate */ TRUE, - asyncVariantLookup, /* level */ level); _ASSERTE(pWrappedMD->IsSharedByGenericInstantiations()); diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp index 0724c7ac359723..8041129865a390 100644 --- a/src/coreclr/vm/jitinterface.cpp +++ b/src/coreclr/vm/jitinterface.cpp @@ -8997,10 +8997,15 @@ CORINFO_METHOD_HANDLE CEEInfo::getAsyncOtherVariant( MethodDesc* pMD = GetMethod(ftn); MethodDesc* pAsyncOtherVariant = NULL; - if (pMD->HasAsyncOtherVariant()) + if (pMD->ReturnsTaskOrValueTask()) { - pAsyncOtherVariant = pMD->GetAsyncOtherVariant(); + pAsyncOtherVariant = pMD->GetAsyncVariant(); } + else if (pMD->IsAsyncVariantMethod()) + { + pAsyncOtherVariant = pMD->GetOrdinaryVariant(); + } + result = (CORINFO_METHOD_HANDLE)pAsyncOtherVariant; *variantIsThunk = pAsyncOtherVariant != NULL && pAsyncOtherVariant->IsAsyncThunkMethod(); diff --git a/src/coreclr/vm/memberload.cpp b/src/coreclr/vm/memberload.cpp index ed734b8da3fac0..7f91a27ab4a490 100644 --- a/src/coreclr/vm/memberload.cpp +++ b/src/coreclr/vm/memberload.cpp @@ -781,7 +781,6 @@ MemberLoader::GetMethodDescFromMemberDefOrRefOrSpec( allowInstParam, /* forceRemotableMethod */ FALSE, /* allowCreate */ TRUE, - AsyncVariantLookup::MatchingAsyncVariant, /* level */ owningTypeLoadLevel); } // MemberLoader::GetMethodDescFromMemberDefOrRefOrSpec diff --git a/src/coreclr/vm/method.cpp b/src/coreclr/vm/method.cpp index 9295fb3d38e534..6e4772c539db5c 100644 --- a/src/coreclr/vm/method.cpp +++ b/src/coreclr/vm/method.cpp @@ -2390,7 +2390,7 @@ bool IsTypeDefOrRefImplementedInSystemModule(Module* pModule, mdToken tk) return false; } -MethodReturnKind ClassifyMethodReturnKind(SigPointer sig, Module* pModule, ULONG* offsetOfAsyncDetails, bool *isValueTask) +MethodReturnKind ClassifyMethodReturnKind(SigPointer sig, Module* pModule, ULONG* offsetOfAsyncDetails, ULONG* elementTypeLength, bool *isValueTask) { PCCOR_SIGNATURE initialSig = sig.GetPtr(); uint32_t data; @@ -2433,7 +2433,12 @@ MethodReturnKind ClassifyMethodReturnKind(SigPointer sig, Module* pModule, ULONG if ((strcmp(name, *isValueTask ? "ValueTask`1" : "Task`1") == 0) && strcmp(_namespace, "System.Threading.Tasks") == 0) { if (IsTypeDefOrRefImplementedInSystemModule(pModule, tk)) + { + PCCOR_SIGNATURE elementStart = sig.GetPtr(); + sig.SkipExactlyOne(); + *elementTypeLength = (ULONG)(sig.GetPtr() - elementStart); return MethodReturnKind::GenericTaskReturningMethod; + } } } } diff --git a/src/coreclr/vm/method.hpp b/src/coreclr/vm/method.hpp index e196bf575bb9ce..9188d081cd5ee4 100644 --- a/src/coreclr/vm/method.hpp +++ b/src/coreclr/vm/method.hpp @@ -76,6 +76,8 @@ enum class AsyncMethodFlags IsAsyncVariantForValueTask = 8, // Method has synthetic body, which forwards to the other variant. Thunk = 16, + // A special thunk to drop return value in covariant return scenario + ReturnDroppingThunk = 32, // The rest of the methods that are not in any of the above groups. // Such methods are not interesting to the Runtime Async feature. // Note: Generic T-returning methods are classified as "None", even if T could be a Task. @@ -95,7 +97,7 @@ enum class AsyncMethodFlags // Example: "Task Foo();" ===> "int Foo();" // Example: "ValueTask Bar();" ===> "void Bar();" // - // It is possible to get from one variant to another via GetAsyncOtherVariant. + // It is possible to get from one variant to another via GetAsyncVariant/GetOrdinaryVariant. // // NOTE: Not all AsyncCall methods are "variants" from a pair. // Methods that are explicitly declared as MethodImpl.Async in metadata while @@ -265,8 +267,8 @@ using PTR_MethodDescCodeData = DPTR(MethodDescCodeData); enum class AsyncVariantLookup { - MatchingAsyncVariant = 0, - AsyncOtherVariant + Ordinary = 0, + Async }; enum class MethodReturnKind @@ -277,7 +279,7 @@ enum class MethodReturnKind }; bool IsTypeDefOrRefImplementedInSystemModule(Module* pModule, mdToken tk); -MethodReturnKind ClassifyMethodReturnKind(SigPointer sig, Module* pModule, ULONG* offsetOfAsyncDetails, bool *isValueTask); +MethodReturnKind ClassifyMethodReturnKind(SigPointer sig, Module* pModule, ULONG* offsetOfAsyncDetails, ULONG* elementTypeLength, bool *isValueTask); inline bool IsTaskReturning(MethodReturnKind input) { @@ -1701,45 +1703,70 @@ class MethodDesc BOOL forceBoxedEntryPoint, Instantiation methodInst, BOOL allowInstParam, + AsyncVariantLookup variantLookup, BOOL forceRemotableMethod = FALSE, BOOL allowCreate = TRUE, - AsyncVariantLookup variantLookup = AsyncVariantLookup::MatchingAsyncVariant, ClassLoadLevel level = CLASS_LOADED); - // Normalize methoddesc for reflection - static MethodDesc* FindOrCreateAssociatedMethodDescForReflection(MethodDesc *pMethod, - TypeHandle instType, - Instantiation methodInst); - - inline bool HasAsyncOtherVariant() const - { - return IsAsyncVariantMethod() || ReturnsTaskOrValueTask(); + // Common Case: same async variant kind as pPrimaryMD + static MethodDesc* FindOrCreateAssociatedMethodDesc(MethodDesc* pPrimaryMD, + MethodTable* pExactMT, + BOOL forceBoxedEntryPoint, + Instantiation methodInst, + BOOL allowInstParam, + BOOL forceRemotableMethod = FALSE, + BOOL allowCreate = TRUE, + ClassLoadLevel level = CLASS_LOADED) + { + // If this assert fires, we may just need to add a lookup that matches AsyncMethodFlags::ReturnDroppingThunk + // It does not look like there is a scenario for directly calling ReturnDroppingThunk right now. + _ASSERTE(!pPrimaryMD->IsReturnDroppingThunk()); + // by default async lookup matches the primaryMD + AsyncVariantLookup variantLookup = pPrimaryMD->IsAsyncVariantMethod() ? AsyncVariantLookup::Async : AsyncVariantLookup::Ordinary; + + return FindOrCreateAssociatedMethodDesc( + pPrimaryMD, + pExactMT, + forceBoxedEntryPoint, + methodInst, + allowInstParam, + variantLookup, + forceRemotableMethod, + allowCreate, + level); } - MethodDesc* GetAsyncOtherVariant(BOOL allowInstParam = TRUE) + // Normalize methoddesc for reflection + static MethodDesc* FindOrCreateAssociatedMethodDescForReflection(MethodDesc* pMethod, + TypeHandle instType, + Instantiation methodInst); + + MethodDesc* GetOrdinaryVariant(BOOL allowInstParam = TRUE) { - _ASSERTE(HasAsyncOtherVariant()); - return FindOrCreateAssociatedMethodDesc(this, GetMethodTable(), FALSE, GetMethodInstantiation(), allowInstParam, FALSE, TRUE, AsyncVariantLookup::AsyncOtherVariant); + MethodTable* mt = GetMethodTable(); + return FindOrCreateAssociatedMethodDesc(this, mt, FALSE, GetMethodInstantiation(), allowInstParam, AsyncVariantLookup::Ordinary, FALSE, TRUE, mt->GetLoadLevel()); } - MethodDesc* GetAsyncOtherVariantNoCreate(BOOL allowInstParam = TRUE) + // same as above, but with allowCreate = FALSE + // for rare cases where we cannot allow GC, but we know that the other variant is already created. + MethodDesc* GetOrdinaryVariantNoCreate(BOOL allowInstParam = TRUE) { - _ASSERTE(HasAsyncOtherVariant()); - return FindOrCreateAssociatedMethodDesc(this, GetMethodTable(), FALSE, GetMethodInstantiation(), allowInstParam, FALSE, FALSE, AsyncVariantLookup::AsyncOtherVariant); + MethodTable* mt = GetMethodTable(); + return FindOrCreateAssociatedMethodDesc(this, mt, FALSE, GetMethodInstantiation(), allowInstParam, AsyncVariantLookup::Ordinary, FALSE, FALSE, mt->GetLoadLevel()); } MethodDesc* GetAsyncVariant(BOOL allowInstParam = TRUE) { - _ASSERT(!IsAsyncVariantMethod()); - return FindOrCreateAssociatedMethodDesc(this, GetMethodTable(), FALSE, GetMethodInstantiation(), allowInstParam, FALSE, TRUE, AsyncVariantLookup::AsyncOtherVariant); + MethodTable* mt = GetMethodTable(); + return FindOrCreateAssociatedMethodDesc(this, mt, FALSE, GetMethodInstantiation(), allowInstParam, AsyncVariantLookup::Async, FALSE, TRUE, mt->GetLoadLevel()); } // same as above, but with allowCreate = FALSE // for rare cases where we cannot allow GC, but we know that the other variant is already created. MethodDesc* GetAsyncVariantNoCreate(BOOL allowInstParam = TRUE) { - _ASSERT(!IsAsyncVariantMethod()); - return FindOrCreateAssociatedMethodDesc(this, GetMethodTable(), FALSE, GetMethodInstantiation(), allowInstParam, FALSE, FALSE, AsyncVariantLookup::AsyncOtherVariant); + MethodTable* mt = GetMethodTable(); + return FindOrCreateAssociatedMethodDesc(this, mt, FALSE, GetMethodInstantiation(), allowInstParam, AsyncVariantLookup::Async, FALSE, FALSE, mt->GetLoadLevel()); } // True if a MD is an funny BoxedEntryPointStub (not from the method table) or @@ -2017,6 +2044,38 @@ class MethodDesc return hasAsyncFlags(asyncFlags, AsyncMethodFlags::IsAsyncVariant); } + inline bool IsReturnDroppingThunk() const + { + LIMITED_METHOD_DAC_CONTRACT; + if (!HasAsyncMethodData()) + return false; + + AsyncMethodFlags asyncFlags = GetAddrOfAsyncMethodData()->flags; + return hasAsyncFlags(asyncFlags, AsyncMethodFlags::ReturnDroppingThunk); + } + + inline bool MatchesAsyncVariantLookup(AsyncVariantLookup lookup) const + { + LIMITED_METHOD_DAC_CONTRACT; + + if (lookup == AsyncVariantLookup::Ordinary) + return !IsAsyncVariantMethod(); + + if (lookup == AsyncVariantLookup::Async) + { + if (!HasAsyncMethodData()) + return false; + + // Note: AsyncVariantLookup::Async only matches regular async variants. ReturnDroppingThunk intentionally + // does not match any lookups. Noone should call ReturnDroppingThunk directly. The only way it gets + // invoked is when it adds itself as a virtual override to a regular async variant. + AsyncMethodFlags asyncFlags = GetAddrOfAsyncMethodData()->flags; + return hasAsyncFlags(asyncFlags, AsyncMethodFlags::IsAsyncVariant) && !hasAsyncFlags(asyncFlags, AsyncMethodFlags::ReturnDroppingThunk); + } + + return false; + } + // Is this an Async variant method for a method that // returns ValueTask or ValueTask ? inline bool IsAsyncVariantForValueTaskReturningMethod() const @@ -2225,7 +2284,9 @@ class MethodDesc bool TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMETHOD_DECODER** methodILDecoder); void EmitTaskReturningThunk(MethodDesc* pAsyncCallVariant, MetaSig& thunkMsig, ILStubLinker* pSL); void EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig& msig, ILStubLinker* pSL); + void EmitReturnDroppingThunk(MethodDesc* pAsyncOtherVariant, MetaSig& msig, ILStubLinker* pSL); SigPointer GetAsyncThunkResultTypeSig(); + int GetTokenForThunkTarget(ILCodeStream* pCode, MethodDesc* md); int GetTokenForGenericMethodCallWithAsyncReturnType(ILCodeStream* pCode, MethodDesc* md); int GetTokenForGenericTypeMethodCallWithAsyncReturnType(ILCodeStream* pCode, MethodDesc* md); public: diff --git a/src/coreclr/vm/methodtable.cpp b/src/coreclr/vm/methodtable.cpp index 0234a880c986fe..79f080cc339280 100644 --- a/src/coreclr/vm/methodtable.cpp +++ b/src/coreclr/vm/methodtable.cpp @@ -5695,7 +5695,6 @@ namespace FALSE, // allowInstParam TRUE, // forceRemoteableMethod TRUE, // allowCreate - AsyncVariantLookup::MatchingAsyncVariant, level // level ); } @@ -7919,7 +7918,7 @@ namespace } } -MethodDesc* MethodTable::GetParallelMethodDesc(MethodDesc* pDefMD, AsyncVariantLookup asyncVariantLookup) +MethodDesc* MethodTable::GetParallelMethodDesc(MethodDesc* pDefMD) { CONTRACTL { @@ -7929,30 +7928,43 @@ MethodDesc* MethodTable::GetParallelMethodDesc(MethodDesc* pDefMD, AsyncVariantL } CONTRACTL_END; - if (asyncVariantLookup == AsyncVariantLookup::MatchingAsyncVariant) - { #ifdef FEATURE_METADATA_UPDATER - if (pDefMD->IsEnCAddedMethod()) - return GetParallelMethodDescForEnC(this, pDefMD); + if (pDefMD->IsEnCAddedMethod()) + { + return GetParallelMethodDescForEnC(this, pDefMD); + } #endif // FEATURE_METADATA_UPDATER - return GetMethodDescForSlot_NoThrow(pDefMD->GetSlot()); // TODO! We should probably use the throwing variant where possible + return GetMethodDescForSlot_NoThrow(pDefMD->GetSlot()); +} + +MethodDesc* MethodTable::GetParallelMethodDesc(MethodDesc* pDefMD, AsyncVariantLookup asyncVariantLookup) +{ + CONTRACTL + { + NOTHROW; + GC_NOTRIGGER; + MODE_ANY; + } + CONTRACTL_END; + + if (pDefMD->MatchesAsyncVariantLookup(asyncVariantLookup)) + { + return GetParallelMethodDesc(pDefMD); } else { - // Slow path for finding the Async variant (or not-Async variant, if we start from Async one) + // Slow path for finding the matching async variant. // This could be optimized with some trickery around slot numbers, but doing so is ... confusing, so I'm not implementing this yet mdMethodDef tkMethod = pDefMD->GetMemberDef(); Module* mod = pDefMD->GetModule(); - bool isAsyncVariantMethod = pDefMD->IsAsyncVariantMethod(); - MethodTable::IntroducedMethodIterator it(this); for (; it.IsValid(); it.Next()) { MethodDesc* pMD = it.GetMethodDesc(); if (pMD->GetMemberDef() == tkMethod && pMD->GetModule() == mod - && pMD->IsAsyncVariantMethod() != isAsyncVariantMethod) + && (pMD->MatchesAsyncVariantLookup(asyncVariantLookup))) { return pMD; } @@ -8341,18 +8353,19 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType } bool differsByAsyncVariant = false; + _ASSERTE(!pMethodDecl->IsAsyncVariantMethod()); if (!pMethodDecl->HasSameMethodDefAs(pInterfaceMD)) { if (pMethodDecl->GetMemberDef() == pInterfaceMD->GetMemberDef() && pMethodDecl->GetModule() == pInterfaceMD->GetModule() && - pMethodDecl->IsAsyncVariantMethod() != pInterfaceMD->IsAsyncVariantMethod()) + pInterfaceMD->IsAsyncVariantMethod()) { differsByAsyncVariant = true; - pMethodDecl = pMethodDecl->GetAsyncOtherVariant(); + pMethodDecl = pMethodDecl->GetAsyncVariant(); if (verifyImplemented) { // if only asked to verify, return pMethodDecl as a success (not NULL) - // otherwise GetAsyncOtherVariant down below will trigger verifying again and we will keep coming here + // otherwise GetAsyncVariant down below will trigger verifying again and we will keep coming here _ASSERTE(pMethodDecl != NULL); return pMethodDecl; } @@ -8388,7 +8401,7 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType if (differsByAsyncVariant) { - pMethodImpl = pMethodImpl->GetAsyncOtherVariant(); + pMethodImpl = pMethodImpl->GetAsyncVariant(); } if (!verifyImplemented && instantiateMethodParameters) @@ -8401,7 +8414,6 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType /* allowInstParam */ FALSE, /* forceRemotableMethod */ FALSE, /* allowCreate */ TRUE, - AsyncVariantLookup::MatchingAsyncVariant, /* level */ level); } if (pMethodImpl != nullptr) diff --git a/src/coreclr/vm/methodtable.h b/src/coreclr/vm/methodtable.h index 23cc148a09fe57..b4bd831b87f604 100644 --- a/src/coreclr/vm/methodtable.h +++ b/src/coreclr/vm/methodtable.h @@ -1805,8 +1805,10 @@ class MethodTable // Returns MethodTable that GetRestoredSlot get its values from MethodTable * GetRestoredSlotMT(DWORD slot); - // Used to map methods on the same slot between instantiations. - MethodDesc * GetParallelMethodDesc(MethodDesc * pDefMD, AsyncVariantLookup asyncVariantLookup = (AsyncVariantLookup)0); + // Used to map to "the same" method between instantiations. + MethodDesc* GetParallelMethodDesc(MethodDesc* pDefMD); + // Maps methods between instantiations + filters/adjusts the result according to the lookup. + MethodDesc* GetParallelMethodDesc(MethodDesc* pDefMD, AsyncVariantLookup asyncVariantLookup); //------------------------------------------------------------------- // BoxedEntryPoint MethodDescs. diff --git a/src/coreclr/vm/methodtablebuilder.cpp b/src/coreclr/vm/methodtablebuilder.cpp index aba716ed8fabdc..038fa7a8d61417 100644 --- a/src/coreclr/vm/methodtablebuilder.cpp +++ b/src/coreclr/vm/methodtablebuilder.cpp @@ -2697,8 +2697,9 @@ MethodTableBuilder::EnumerateClassMethods() // In a worst case the number of declared methods can double // as each async method may have two variants. // The method count is typically a modest number though. - // We will reserve twice the size for the builder, up to the max, just in case. - DWORD cMethUpperBound = cMethAndGaps * 2; + // If we have covariant overrides, such as a base Task method overridden by Task, we will need 3 method descs. + // Reserve the space conservatively, up to the max, for the worst case scenario. + DWORD cMethUpperBound = cMethAndGaps * (bmtMetaData->fHasCovariantOverride ? 3 : 2); if ((DWORD)MAX_SLOT_INDEX <= cMethUpperBound) { cMethUpperBound = MAX_SLOT_INDEX - 1; @@ -2782,10 +2783,11 @@ MethodTableBuilder::EnumerateClassMethods() SigParser sig(pMemberSignature, cMemberSignature); ULONG offsetOfAsyncDetails = 0; + ULONG elementTypeLength = 0; bool returnsValueTask = false; MethodReturnKind returnKind = IsDelegate() ? MethodReturnKind::NormalMethod : - ClassifyMethodReturnKind(sig, GetModule(), &offsetOfAsyncDetails, &returnsValueTask); + ClassifyMethodReturnKind(sig, GetModule(), &offsetOfAsyncDetails, &elementTypeLength, &returnsValueTask); bool hasGenericMethodArgsComputed = false; bool hasGenericMethodArgs = this->GetModule()->m_pMethodIsGenericMap->IsGeneric(tok, &hasGenericMethodArgsComputed); @@ -3338,8 +3340,7 @@ MethodTableBuilder::EnumerateClassMethods() // Create a new bmtMDMethod representing this method and add it to the // declared method list. // - bmtMDMethod *pDeclaredMethod = NULL; - for (int insertCount = 0; insertCount < 2; insertCount++) + for (int insertCount = 0; insertCount < 3; insertCount++) { if (bmtMethod->m_cDeclaredMethods >= bmtMethod->m_cMaxDeclaredMethods) { @@ -3391,12 +3392,10 @@ MethodTableBuilder::EnumerateClassMethods() pNewMethod->SetAsyncMethodFlags(AsyncMethodFlags::None); } } - - pDeclaredMethod = pNewMethod; } else { - // Second pass, add the async variant. + // Extra pass, add an async variant. ULONG cAsyncThunkMemberSignature; ULONG taskTokenOffsetFromAsyncDetailsOffset; @@ -3413,20 +3412,42 @@ MethodTableBuilder::EnumerateClassMethods() if (!IsMiAsync(dwImplFlags)) asyncFlags |= AsyncMethodFlags::Thunk; + if (insertCount == 2) + asyncFlags |= (AsyncMethodFlags::Thunk | AsyncMethodFlags::ReturnDroppingThunk); + // Here we construct the signature of async call variant given its task-returning counterpart. // It is basically just removing the Task/ValueTask part of the return type and keeping // the token for T or inserting void instead. // The rest of the signature stays exactly the same. - ULONG tokenLen = 0; - if (returnKind == MethodReturnKind::NonGenericTaskReturningMethod) + ULONG taskTokenLen = 0; + + if (insertCount == 2) + { + // This is a rare case when we need two async variants and this is the second one. + // The need arises when a Task-returning method has a Task returning virtual override. + // We need an extra void-returning thunk that can override the void-returning async variant in the base. + // The thunk's implementation simply calls the T-returning async variant and ignores the return. + + // from ". . . Task . . . Method(args);" we construct + // ". . . void . . . Method(args);" + + taskTokenOffsetFromAsyncDetailsOffset = 2; + taskTokenLen = CorSigUncompressedDataSize(&pMemberSignature[offsetOfAsyncDetails + taskTokenOffsetFromAsyncDetailsOffset]); + + taskTypePrefixSize = 2 + taskTokenLen + 1 + elementTypeLength; // E_T_GENERICINST E_T_CLASS/E_T_VALUETYPE 1 + taskTypePrefixReplacementSize = 1; // ELEMENT_TYPE_VOID + + cAsyncThunkMemberSignature = cMemberSignature - taskTypePrefixSize + taskTypePrefixReplacementSize; + } + else if (returnKind == MethodReturnKind::NonGenericTaskReturningMethod) { // from ". . . Task . . . Method(args);" we construct // ". . . void . . . Method(args);" taskTokenOffsetFromAsyncDetailsOffset = 1; - tokenLen = CorSigUncompressedDataSize(&pMemberSignature[offsetOfAsyncDetails + taskTokenOffsetFromAsyncDetailsOffset]); + taskTokenLen = CorSigUncompressedDataSize(&pMemberSignature[offsetOfAsyncDetails + taskTokenOffsetFromAsyncDetailsOffset]); - taskTypePrefixSize = 1 + tokenLen; // E_T_CLASS/E_T_VALUETYPE + taskTypePrefixSize = 1 + taskTokenLen; // E_T_CLASS/E_T_VALUETYPE taskTypePrefixReplacementSize = 1; // ELEMENT_TYPE_VOID cAsyncThunkMemberSignature = cMemberSignature - taskTypePrefixSize + taskTypePrefixReplacementSize; @@ -3437,9 +3458,9 @@ MethodTableBuilder::EnumerateClassMethods() // ". . . tk . . . Method(args);" taskTokenOffsetFromAsyncDetailsOffset = 2; - tokenLen = CorSigUncompressedDataSize(&pMemberSignature[offsetOfAsyncDetails + taskTokenOffsetFromAsyncDetailsOffset]); + taskTokenLen = CorSigUncompressedDataSize(&pMemberSignature[offsetOfAsyncDetails + taskTokenOffsetFromAsyncDetailsOffset]); - taskTypePrefixSize = 2 + tokenLen + 1; // E_T_GENERICINST E_T_CLASS/E_T_VALUETYPE 1 + taskTypePrefixSize = 2 + taskTokenLen + 1; // E_T_GENERICINST E_T_CLASS/E_T_VALUETYPE 1 taskTypePrefixReplacementSize = 0; cAsyncThunkMemberSignature = cMemberSignature - taskTypePrefixSize + taskTypePrefixReplacementSize; @@ -3461,7 +3482,7 @@ MethodTableBuilder::EnumerateClassMethods() _ASSERTE((cMemberSignature - originalRemainingSigOffset) == (cAsyncThunkMemberSignature - newRemainingSigOffset)); memcpy(pNewMemberSignature + newRemainingSigOffset, pMemberSignature + originalRemainingSigOffset, cMemberSignature - originalRemainingSigOffset); - if (returnKind == MethodReturnKind::NonGenericTaskReturningMethod) + if (returnKind == MethodReturnKind::NonGenericTaskReturningMethod || insertCount == 2) { pNewMemberSignature[newRemainingSigOffset - 1] = ELEMENT_TYPE_VOID; } @@ -3487,9 +3508,6 @@ MethodTableBuilder::EnumerateClassMethods() asyncVariantType, implType); - pNewMethod->SetAsyncOtherVariant(pDeclaredMethod); - pDeclaredMethod->SetAsyncOtherVariant(pNewMethod); - #ifdef FEATURE_COMINTEROP // We only ever include one of the two async variants (whichever doesn't have the async calling convention) // Record an excluded method here in the COM VTable. @@ -3522,6 +3540,23 @@ MethodTableBuilder::EnumerateClassMethods() { break; } + + // In rare cases we need a void-returning async variant in addition to the T-returning one. + // It is ok to add a void-returning thunk and end up not using it, but we want to avoid waste. + // Thus we try to filter closer to the cases when the thunk most certainly will be used. + if (insertCount == 1) + { + if (!bmtMetaData->fHasCovariantOverride || + implType != METHOD_IMPL || + returnsValueTask || + returnKind != MethodReturnKind::GenericTaskReturningMethod || + this->IsValueClass() || + !IsMdVirtual(dwMemberAttrs)) + { + // No need for another variant + break; + } + } } } @@ -5546,8 +5581,8 @@ MethodTableBuilder::PlaceVirtualMethods() } } -// Given an interface map entry, and a name+signature, compute the method on the interface -// that the name+signature corresponds to. Used by ProcessMethodImpls and ProcessInexactMethodImpls +// Given an interface map entry, and a name+signature+variantLookup, compute the method on the interface +// that the name+signature+variantLookup corresponds to. Used by ProcessMethodImpls and ProcessInexactMethodImpls // Always returns the first match that it finds. Affects the ambiguities in code:#ProcessInexactMethodImpls_Ambiguities MethodTableBuilder::bmtMethodHandle MethodTableBuilder::FindDeclMethodOnInterfaceEntry(bmtInterfaceEntry *pItfEntry, MethodSignature &declSig, AsyncVariantLookup variantLookup, bool searchForStaticMethods) @@ -5587,7 +5622,10 @@ MethodTableBuilder::FindDeclMethodOnInterfaceEntry(bmtInterfaceEntry *pItfEntry, } } - if (variantLookup == AsyncVariantLookup::AsyncOtherVariant && !declMethod.IsNull()) + // declSig is for an ordinary method, we should not find an async variant. + _ASSERTE(declMethod.IsNull() || !declMethod.GetMethodDesc()->IsAsyncVariantMethod()); + + if (variantLookup != AsyncVariantLookup::Ordinary && !declMethod.IsNull()) { bmtRTMethod* declRTMethod = declMethod.AsRTMethod(); // Other variant may not exist. For example we return Task and the base is generic and returns T. @@ -5600,7 +5638,7 @@ MethodTableBuilder::FindDeclMethodOnInterfaceEntry(bmtInterfaceEntry *pItfEntry, if ((slotDeclMethod->GetOwningType() == declRTMethod->GetOwningType()) && (slotDeclMethod->GetMethodDesc()->GetMethodTable() == declRTMethod->GetMethodDesc()->GetMethodTable()) && (slotDeclMethod->GetMethodDesc()->GetMemberDef() == declRTMethod->GetMethodDesc()->GetMemberDef()) && - (slotDeclMethod->GetMethodDesc()->IsAsyncVariantMethod() != declRTMethod->GetMethodDesc()->IsAsyncVariantMethod())) + (slotDeclMethod->GetMethodDesc()->MatchesAsyncVariantLookup(variantLookup))) { declMethod = slotIt->Decl(); break; @@ -5676,9 +5714,9 @@ MethodTableBuilder::ProcessInexactMethodImpls() continue; } - AsyncVariantLookup asyncVariantOfDeclToFind = !it->IsAsyncVariant() ? - AsyncVariantLookup::MatchingAsyncVariant : - AsyncVariantLookup::AsyncOtherVariant; + AsyncVariantLookup asyncVariantOfDeclToFind = it->IsAsyncVariant() ? + AsyncVariantLookup::Async : + AsyncVariantLookup::Ordinary; // If this method serves as the BODY of a MethodImpl specification, then // we should iterate all the MethodImpl's for this class and see just how many @@ -5821,9 +5859,9 @@ MethodTableBuilder::ProcessMethodImpls() continue; } - AsyncVariantLookup asyncVariantOfDeclToFind = !it->IsAsyncVariant() ? - AsyncVariantLookup::MatchingAsyncVariant : - AsyncVariantLookup::AsyncOtherVariant; + AsyncVariantLookup asyncVariantOfDeclToFind = it->IsAsyncVariant() ? + AsyncVariantLookup::Async : + AsyncVariantLookup::Ordinary; // If this method serves as the BODY of a MethodImpl specification, then // we should iterate all the MethodImpl's for this class and see just how many @@ -6008,11 +6046,23 @@ MethodTableBuilder::ProcessMethodImpls() declMethod = FindDeclMethodOnClassInHierarchy(it, pDeclMT, declSig, asyncVariantOfDeclToFind); } - if (declMethod.IsNull() && asyncVariantOfDeclToFind == AsyncVariantLookup::AsyncOtherVariant) + if (asyncVariantOfDeclToFind == AsyncVariantLookup::Async && + (declMethod.IsNull() || + !MethodSignature::SignaturesEquivalent(declMethod.GetMethodSignature(), it->GetMethodSignature(), FALSE))) { - // when implementing/overriding, we may see a Task-returning method - // which matches a T-returning method in the interface/base, which would not have variants. - // in such case the async variant of the Task-returning method does not implement/override anything. + // There are two scenarios when an async variant may not find a base to override: + // + // 1. We have a Task-returning method that is Task-returning due to generic substitution of the return type. + // The base method is T-returning and thus does not have an async variant that we can override. + // + // 2. We may have added a void-returning async thunk in anticipation of covariant Task -> Task override. + // The thunk is added very early based on limited type system information and it is not 100% guaranteed that + // we actually have Task -> Task situation. (i.e. we may have Object -> Task override or some other case...) + // When this happens the thunk does not override anything. + // + // It is ok in the above cases to not have a base. It means that the "impl" method should not be called + // polymorphically. + // continue; } @@ -6150,11 +6200,14 @@ MethodTableBuilder::bmtMethodHandle MethodTableBuilder::FindDeclMethodOnClassInH FALSE, iPass == 0 ? &newVisited : NULL)) { - if (variantLookup == AsyncVariantLookup::AsyncOtherVariant) + // We should find the ordinary variant first. + _ASSERTE(pCurMD->MatchesAsyncVariantLookup(AsyncVariantLookup::Ordinary)); + + if (variantLookup != AsyncVariantLookup::Ordinary) { - if (pCurMD->HasAsyncOtherVariant()) + if (pCurMD->ReturnsTaskOrValueTask()) { - pCurMD = pCurMD->GetAsyncOtherVariant(); + pCurMD = pCurMD->GetAsyncVariant(); } else { diff --git a/src/coreclr/vm/methodtablebuilder.h b/src/coreclr/vm/methodtablebuilder.h index 05d0d9c475fc8c..fece0cc7c6f83f 100644 --- a/src/coreclr/vm/methodtablebuilder.h +++ b/src/coreclr/vm/methodtablebuilder.h @@ -1112,9 +1112,6 @@ class MethodTableBuilder return m_asyncMethodFlags; } - bmtMDMethod * GetAsyncOtherVariant() const { return m_asyncOtherVariant; } - void SetAsyncOtherVariant(bmtMDMethod* pAsyncOtherVariant) { m_asyncOtherVariant = pAsyncOtherVariant; } - private: //----------------------------------------------------------------------------------------- bmtMDType * m_pOwningType; @@ -1126,7 +1123,6 @@ class MethodTableBuilder AsyncMethodFlags m_asyncMethodFlags; METHOD_IMPL_TYPE m_implType; // Whether or not the method is a methodImpl body MethodSignature m_methodSig; - bmtMDMethod* m_asyncOtherVariant = NULL; MethodDesc * m_pMD; // MethodDesc created and assigned to this method MethodDesc * m_pUnboxedMD; // Unboxing MethodDesc if this is a virtual method on a valuetype @@ -2027,11 +2023,7 @@ class MethodTableBuilder if ((*this)[i]->GetMethodSignature().GetToken() == tok) { auto result = (*this)[i]; - if (variantLookup == AsyncVariantLookup::AsyncOtherVariant) - { - return result->GetAsyncOtherVariant(); - } - else + if ((variantLookup == AsyncVariantLookup::Async) == result->IsAsyncVariant()) { return result; } diff --git a/src/coreclr/vm/runtimehandles.cpp b/src/coreclr/vm/runtimehandles.cpp index e0fddcae134d7f..da8093d38779e7 100644 --- a/src/coreclr/vm/runtimehandles.cpp +++ b/src/coreclr/vm/runtimehandles.cpp @@ -1947,7 +1947,7 @@ extern "C" MethodDesc* QCALLTYPE RuntimeMethodHandle_GetStubIfNeededSlow(MethodD if (pMethod->IsAsyncVariantMethod()) { // do not report async variants to reflection. - pMethod = pMethod->GetAsyncOtherVariant(/*allowInstParam*/ false); + pMethod = pMethod->GetOrdinaryVariant(/*allowInstParam*/ false); } TypeHandle instType = declaringTypeHandle.AsTypeHandle(); diff --git a/src/coreclr/vm/stubmgr.cpp b/src/coreclr/vm/stubmgr.cpp index b5161c3f042a66..4be4e51389c50d 100644 --- a/src/coreclr/vm/stubmgr.cpp +++ b/src/coreclr/vm/stubmgr.cpp @@ -2281,8 +2281,17 @@ BOOL AsyncThunkStubManager::TraceManager(Thread *thread, MethodDesc* pMD = NonVirtualEntry2MethodDesc(stubIP); if (pMD->IsAsyncThunkMethod()) { - MethodDesc* pOtherMD = pMD->GetAsyncOtherVariant(); - _ASSERTE_MSG(pOtherMD != NULL, "ATSM::TraceManager: Async thunk has no non-thunk variant to step through to"); + MethodDesc* pOtherMD = pMD->GetOrdinaryVariant(); + _ASSERTE_MSG(pOtherMD != NULL, "ATSM::TraceManager: Async thunk does not have non-async variant"); + + // An ordinary variant may be a thunk in a rare case when we start from ReturnDroppingThunk. + // In such case the regular async variant must not be a thunk. + if (pOtherMD->IsAsyncThunkMethod()) + { + pOtherMD = pMD->GetAsyncVariant(); + _ASSERTE_MSG(pOtherMD != NULL, "ATSM::TraceManager: Async thunk has no non-thunk variant to step through to"); + _ASSERTE(!pOtherMD->IsAsyncThunkMethod()); + } LOG((LF_CORDB, LL_INFO1000, "ATSM::TraceManager: Step through async thunk to target - %p\n", pOtherMD)); PCODE target = GetStubTarget(pOtherMD); diff --git a/src/coreclr/vm/zapsig.cpp b/src/coreclr/vm/zapsig.cpp index 09d9800df3e143..d73e609a01936d 100644 --- a/src/coreclr/vm/zapsig.cpp +++ b/src/coreclr/vm/zapsig.cpp @@ -917,7 +917,7 @@ MethodDesc *ZapSig::DecodeMethod(ModuleBase *pInfoModule, // This must be called even if nargs == 0, in order to create an instantiating - // stub for static methods in generic classees if needed, also for BoxedEntryPointStubs + // stub for static methods in generic classes if needed, also for BoxedEntryPointStubs // in non-generic structs. BOOL isInstantiatingStub = (methodFlags & ENCODE_METHOD_SIG_InstantiatingStub); BOOL isUnboxingStub = (methodFlags & ENCODE_METHOD_SIG_UnboxingStub); @@ -927,9 +927,9 @@ MethodDesc *ZapSig::DecodeMethod(ModuleBase *pInfoModule, isUnboxingStub, inst, !(isInstantiatingStub || isUnboxingStub) && !actualOwnerRequired, + isAsyncVariant ? AsyncVariantLookup::Async : AsyncVariantLookup::Ordinary, actualOwnerRequired, - TRUE, - isAsyncVariant == pMethod->IsAsyncVariantMethod() ? AsyncVariantLookup::MatchingAsyncVariant : AsyncVariantLookup::AsyncOtherVariant); + TRUE); if (methodFlags & ENCODE_METHOD_SIG_Constrained) { diff --git a/src/tests/async/covariant-return/covariant-returns.cs b/src/tests/async/covariant-return/covariant-returns.cs new file mode 100644 index 00000000000000..9aaef5f91f7122 --- /dev/null +++ b/src/tests/async/covariant-return/covariant-returns.cs @@ -0,0 +1,178 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using Xunit; + +public class CovariantReturns +{ + [Fact] + public static void Test0EntryPoint() + { + Test0().Wait(); + } + + [Fact] + public static void Test1EntryPoint() + { + Test1().Wait(); + } + + [Fact] + public static void Test2EntryPoint() + { + Test2().Wait(); + } + + [Fact] + public static void Test2AEntryPoint() + { + Test2A().Wait(); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static async Task Test0() + { + Base b = new Base(); + await b.M1(); + Assert.Equal("Base.M1;", b.Trace); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static async Task Test1() + { + // check year to not be concerned with devirtualization. + Base b = DateTime.Now.Year > 0 ? new Derived() : new Base(); + await b.M1(); + Assert.Equal("Derived.M1;Base.M1;", b.Trace); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static async Task Test2() + { + Base b = DateTime.Now.Year > 0 ? new Derived2() : new Base(); + await b.M1(); + Assert.Equal("Derived2.M1;Derived.M1;Base.M1;", b.Trace); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static async Task Test2A() + { + Base b = DateTime.Now.Year > 0 ? new Derived2A() : new Base(); + await b.M1(); + Assert.Equal("Derived2A.M1;DerivedA.M1;Base.M1;", b.Trace); + } + + struct S1 + { + public Guid guid; + public int num; + + public S1(int num) + { + this.guid = Guid.NewGuid(); + this.num = num; + } + } + + class Base + { + public string Trace; + public virtual Task M1() + { + Trace += "Base.M1;"; + return Task.CompletedTask; + } + } + + class Derived : Base + { + public override Task M1() + { + Trace += "Derived.M1;"; + base.M1().GetAwaiter().GetResult(); + return Task.FromResult(new S1(42)); + } + } + + class Derived2 : Derived + { + public override async Task M1() + { + Trace += "Derived2.M1;"; + await Task.Delay(1); + await base.M1(); + return new S1(4242); + } + } + + class DerivedA : Base + { + public async override Task M1() + { + Trace += "DerivedA.M1;"; + await base.M1(); + return new S1(42); + } + } + + class Derived2A : DerivedA + { + public override async Task M1() + { + Trace += "Derived2A.M1;"; + await Task.Delay(1); + await base.M1(); + return new S1(4242); + } + } +} + +namespace AsyncMicro +{ + public class Program + { + internal static string Trace; + + [Fact] + public static void TestPrRepro() + { + Derived2 test = new(); + Test(test).GetAwaiter().GetResult(); + Assert.Equal("Task Derived2.Foo;Task Derived.Foo;", Trace); + } + + private static async Task Test(Base b) + { + await b.Foo(); + } + + public class Base + { + public virtual async Task Foo() + { + Trace += "Task Base.Foo;"; + } + } + + public class Derived : Base + { + public override async Task Foo() + { + Trace += "Task Derived.Foo;"; + return 123; + } + } + + public class Derived2 : Derived + { + public override async Task Foo() + { + Trace += "Task Derived2.Foo;"; + return await base.Foo(); + } + } + } +} diff --git a/src/tests/async/covariant-return/covariant-returns.csproj b/src/tests/async/covariant-return/covariant-returns.csproj new file mode 100644 index 00000000000000..4f69c773ab13b1 --- /dev/null +++ b/src/tests/async/covariant-return/covariant-returns.csproj @@ -0,0 +1,9 @@ + + + + true + + + + +