From 76d4b7df3ec9324616a580c099e49d7256b83fea Mon Sep 17 00:00:00 2001 From: Aaron Robinson Date: Sat, 13 Apr 2019 10:43:18 -0700 Subject: [PATCH 1/5] Initial blocking for default interfaces in COM scenarios --- src/dlls/mscorrc/mscorrc.rc | 3 ++- src/dlls/mscorrc/resource.h | 2 +- src/vm/class.h | 14 +++++++++++++- src/vm/comcallablewrapper.cpp | 7 ++++++- src/vm/methodtable.cpp | 8 ++++++++ src/vm/methodtable.h | 3 +++ src/vm/methodtablebuilder.cpp | 5 +++++ src/vm/stdinterfaces.cpp | 13 ++++++++++++- .../DefaultInterfacesTests.cpp | 18 +++--------------- 9 files changed, 53 insertions(+), 20 deletions(-) diff --git a/src/dlls/mscorrc/mscorrc.rc b/src/dlls/mscorrc/mscorrc.rc index c733f694ee76..3576b48b0471 100644 --- a/src/dlls/mscorrc/mscorrc.rc +++ b/src/dlls/mscorrc/mscorrc.rc @@ -913,8 +913,9 @@ BEGIN IDS_EE_COM_INVISIBLE_PARENT "Type '%1' has a ComVisible(false) parent '%2' in its hierarchy, therefore QueryInterface calls for IDispatch or class interfaces are disallowed." - IDS_EE_COMIMPORT_METHOD_NO_INTERFACE "Method '%1' in ComImport class '%2' must implement an interface method." IDS_EE_ATTEMPT_TO_CREATE_GENERIC_CCW "Generic types cannot be marshaled to COM interface pointers." + IDS_EE_ATTEMPT_TO_CREATE_NON_ABSTRACT_CCW "Types with non-abstract methods cannot be marshaled to COM interface pointers." + IDS_EE_COMIMPORT_METHOD_NO_INTERFACE "Method '%1' in ComImport class '%2' must implement an interface method." IDS_CLASSLOAD_BAD_METHOD_COUNT "Metadata method count does not match method enumeration length for type '%1' from assembly '%2'." IDS_CLASSLOAD_BAD_FIELD_COUNT "Metadata field count does not match field enumeration length for type '%1' from assembly '%2'." diff --git a/src/dlls/mscorrc/resource.h b/src/dlls/mscorrc/resource.h index f11aa12fecb6..01982d324c59 100644 --- a/src/dlls/mscorrc/resource.h +++ b/src/dlls/mscorrc/resource.h @@ -440,7 +440,7 @@ #define IDS_EE_PROFILING_FAILURE 0x1aa8 #define IDS_EE_ATTEMPT_TO_CREATE_GENERIC_CCW 0x1aa9 - +#define IDS_EE_ATTEMPT_TO_CREATE_NON_ABSTRACT_CCW 0x1aaa #define IDS_EE_COMIMPORT_METHOD_NO_INTERFACE 0x1aab #define IDS_EE_OUT_OF_MEMORY_WITHIN_RANGE 0x1aac #define IDS_EE_ARRAY_DIMENSIONS_EXCEEDED 0x1aad diff --git a/src/vm/class.h b/src/vm/class.h index 2853aee330e2..3f678430f2b3 100644 --- a/src/vm/class.h +++ b/src/vm/class.h @@ -1349,6 +1349,11 @@ class EEClass // DO NOT CREATE A NEW EEClass USING NEW! LIMITED_METHOD_CONTRACT; m_VMFlags |= (DWORD) VMFLAG_FIXED_ADDRESS_VT_STATICS; } + void SetHasOnlyAbstractMethods() + { + LIMITED_METHOD_CONTRACT; + m_VMFlags |= (DWORD) VMFLAG_ONLY_ABSTRACT_METHODS; + } #ifdef FEATURE_COMINTEROP void SetSparseForCOMInterop() { @@ -1430,6 +1435,13 @@ class EEClass // DO NOT CREATE A NEW EEClass USING NEW! LIMITED_METHOD_CONTRACT; return m_VMFlags & VMFLAG_FIXED_ADDRESS_VT_STATICS; } + + BOOL HasOnlyAbstractMethods() + { + LIMITED_METHOD_CONTRACT; + return m_VMFlags & VMFLAG_ONLY_ABSTRACT_METHODS; + } + #ifdef FEATURE_COMINTEROP BOOL IsSparseForCOMInterop() { @@ -1859,7 +1871,7 @@ class EEClass // DO NOT CREATE A NEW EEClass USING NEW! // unused = 0x00080000, VMFLAG_CONTAINS_STACK_PTR = 0x00100000, VMFLAG_PREFER_ALIGN8 = 0x00200000, // Would like to have 8-byte alignment - // unused = 0x00400000, + VMFLAG_ONLY_ABSTRACT_METHODS = 0x00400000, // Type only contains abstract methods #ifdef FEATURE_COMINTEROP VMFLAG_SPARSE_FOR_COMINTEROP = 0x00800000, diff --git a/src/vm/comcallablewrapper.cpp b/src/vm/comcallablewrapper.cpp index 3eb985120664..908cf8315572 100644 --- a/src/vm/comcallablewrapper.cpp +++ b/src/vm/comcallablewrapper.cpp @@ -3452,7 +3452,12 @@ IUnknown* ComCallWrapper::GetComIPFromCCW(ComCallWrapper *pWrap, REFIID riid, Me { COMPlusThrow(kInvalidOperationException, IDS_EE_ATTEMPT_TO_CREATE_GENERIC_CCW); } - + + if (pIntfMT->IsInterface() && !pIntfMT->HasOnlyAbstractMethods()) + { + COMPlusThrow(kInvalidOperationException, IDS_EE_ATTEMPT_TO_CREATE_NON_ABSTRACT_CCW); + } + // The first block has one slot for the IClassX vtable pointer // and one slot for the basic vtable pointer. imapIndex += Slot_FirstInterface; diff --git a/src/vm/methodtable.cpp b/src/vm/methodtable.cpp index 34381fa2aad2..a5f33b376849 100644 --- a/src/vm/methodtable.cpp +++ b/src/vm/methodtable.cpp @@ -8221,6 +8221,14 @@ DWORD MethodTable::HasFixedAddressVTStatics() return GetClass()->HasFixedAddressVTStatics(); } +//========================================================================================== +BOOL MethodTable::HasOnlyAbstractMethods() +{ + LIMITED_METHOD_CONTRACT; + + return GetClass()->HasOnlyAbstractMethods(); +} + //========================================================================================== WORD MethodTable::GetNumHandleRegularStatics() { diff --git a/src/vm/methodtable.h b/src/vm/methodtable.h index 74febebc39bc..9c128436a7a5 100644 --- a/src/vm/methodtable.h +++ b/src/vm/methodtable.h @@ -2572,6 +2572,9 @@ class MethodTable DWORD HasFixedAddressVTStatics(); + // Indicates if the MethodTable only contains abstract methods + BOOL HasOnlyAbstractMethods(); + //------------------------------------------------------------------- // PER-INSTANTIATION STATICS INFO // diff --git a/src/vm/methodtablebuilder.cpp b/src/vm/methodtablebuilder.cpp index 286cd74c0a27..6cff19266f5c 100644 --- a/src/vm/methodtablebuilder.cpp +++ b/src/vm/methodtablebuilder.cpp @@ -3230,6 +3230,11 @@ MethodTableBuilder::EnumerateClassMethods() } } + if (bmtMethod->dwNumDeclaredNonAbstractMethods == 0) + { + GetHalfBakedClass()->SetHasOnlyAbstractMethods(); + } + // Check to see that we have all of the required delegate methods (ECMA 13.6 Delegates) if (IsDelegate()) { diff --git a/src/vm/stdinterfaces.cpp b/src/vm/stdinterfaces.cpp index 15da1c8c2210..e050f2f66b70 100644 --- a/src/vm/stdinterfaces.cpp +++ b/src/vm/stdinterfaces.cpp @@ -198,7 +198,18 @@ Unknown_QueryInterface_Internal(ComCallWrapper* pWrap, IUnknown* pUnk, REFIID ri // If we haven't found the IP or if we haven't looked yet (because we aren't // being aggregated), now look on the managed object to see if it supports the interface. if (pDestItf == NULL) - pDestItf = ComCallWrapper::GetComIPFromCCW(pWrap, riid, NULL, GetComIPFromCCW::CheckVisibility); + { + EX_TRY + { + pDestItf = ComCallWrapper::GetComIPFromCCW(pWrap, riid, NULL, GetComIPFromCCW::CheckVisibility); + } + EX_CATCH + { + Exception *e = GET_EXCEPTION(); + hr = e->GetHR(); + } + EX_END_CATCH(RethrowTerminalExceptions) + } ErrExit: // If we succeeded in obtaining the requested IP then return S_OK. diff --git a/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp b/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp index 9d87215049e5..7270df8f83f2 100644 --- a/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp +++ b/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp @@ -64,20 +64,8 @@ void CallDefaultInterface() HRESULT hr; ComSmartPtr defInterface; - THROW_IF_FAILED(::CoCreateInstance(CLSID_DefaultInterfaceTesting, nullptr, CLSCTX_INPROC, IID_IDefaultInterfaceTesting, (void**)&defInterface)); + hr = ::CoCreateInstance(CLSID_DefaultInterfaceTesting, nullptr, CLSCTX_INPROC, IID_IDefaultInterfaceTesting, (void**)&defInterface) - int i; - - THROW_IF_FAILED(defInterface->DefOnInterfaceRet2(&i)); - THROW_FAIL_IF_FALSE(i == 2); - - THROW_IF_FAILED(defInterface->DefOnClassRet3(&i)); - THROW_FAIL_IF_FALSE(i == 3); - - // - // Overridden default interface defintions do not work - // https://github.com/dotnet/coreclr/issues/15683 - // - //THROW_IF_FAILED(defInterface->DefOnInterface2Ret5(&i)); - //THROW_FAIL_IF_FALSE(i == 5); + const int COR_E_INVALIDOPERATION = 0x80131509; + THROW_FAIL_IF_FALSE(hr == COR_E_INVALIDOPERATION); } From d44e0bb6db768153bf4f67a9420e59e397494e9e Mon Sep 17 00:00:00 2001 From: Aaron Robinson Date: Sat, 13 Apr 2019 11:58:05 -0700 Subject: [PATCH 2/5] Update DefaultInterfacesTests.cpp --- .../NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp b/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp index 7270df8f83f2..357bea6408e2 100644 --- a/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp +++ b/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp @@ -64,7 +64,7 @@ void CallDefaultInterface() HRESULT hr; ComSmartPtr defInterface; - hr = ::CoCreateInstance(CLSID_DefaultInterfaceTesting, nullptr, CLSCTX_INPROC, IID_IDefaultInterfaceTesting, (void**)&defInterface) + hr = ::CoCreateInstance(CLSID_DefaultInterfaceTesting, nullptr, CLSCTX_INPROC, IID_IDefaultInterfaceTesting, (void**)&defInterface); const int COR_E_INVALIDOPERATION = 0x80131509; THROW_FAIL_IF_FALSE(hr == COR_E_INVALIDOPERATION); From 00dc2aa288b1b3dbda098f1c91262e9a939def3e Mon Sep 17 00:00:00 2001 From: Aaron Robinson Date: Wed, 17 Apr 2019 19:50:26 -0700 Subject: [PATCH 3/5] Validate interface during IClassFactory::CreateInstance() Add additional test cases and move some CCW functions to a more consolidated layout. --- .../Runtime/InteropServices/ComActivator.cs | 31 +- src/vm/comcallablewrapper.cpp | 585 +++++++++--------- src/vm/comcallablewrapper.h | 59 +- .../NETServer/NETServer.DefaultInterfaces.il | 22 +- .../DefaultInterfacesTests.cpp | 51 +- .../COM/ServerContracts/Server.Contracts.h | 6 +- 6 files changed, 387 insertions(+), 367 deletions(-) diff --git a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs index d1c6aa9bb7a4..43f3587a6f6b 100644 --- a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs +++ b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs @@ -235,12 +235,12 @@ public BasicClassFactory(Guid clsid, Type classType) _classType = classType; } - public static void ValidateInterfaceRequest(Type classType, ref Guid riid, object outer) + public static Type GetValidatedInterfaceType(Type classType, ref Guid riid, object outer) { Debug.Assert(classType != null); if (riid == Marshal.IID_IUnknown) { - return; + return typeof(object); } // Aggregation can only be done when requesting IUnknown. @@ -250,22 +250,27 @@ public static void ValidateInterfaceRequest(Type classType, ref Guid riid, objec throw new COMException(string.Empty, CLASS_E_NOAGGREGATION); } - bool found = false; - // Verify the class implements the desired interface foreach (Type i in classType.GetInterfaces()) { if (i.GUID == riid) { - found = true; - break; + return i; } } - if (!found) + // E_NOINTERFACE + throw new InvalidCastException(); + } + + public static void ValidateInterfaceIsMarshallable(object obj, Type interfaceType) + { + Debug.Assert(obj != null && interfaceType != null); + + if (interfaceType != typeof(object)) { - // E_NOINTERFACE - throw new InvalidCastException(); + Debug.Assert(interfaceType.IsInterface); + Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore); } } @@ -291,13 +296,15 @@ public void CreateInstance( ref Guid riid, [MarshalAs(UnmanagedType.Interface)] out object ppvObject) { - BasicClassFactory.ValidateInterfaceRequest(_classType, ref riid, pUnkOuter); + Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter); ppvObject = Activator.CreateInstance(_classType); if (pUnkOuter != null) { ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject); } + + BasicClassFactory.ValidateInterfaceIsMarshallable(ppvObject, interfaceType); } public void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock) @@ -368,13 +375,15 @@ private void CreateInstanceInner( bool isDesignTime, out object ppvObject) { - BasicClassFactory.ValidateInterfaceRequest(_classType, ref riid, pUnkOuter); + Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter); ppvObject = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime); if (pUnkOuter != null) { ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject); } + + BasicClassFactory.ValidateInterfaceIsMarshallable(ppvObject, interfaceType); } } } diff --git a/src/vm/comcallablewrapper.cpp b/src/vm/comcallablewrapper.cpp index 908cf8315572..39cd0a14cfe3 100644 --- a/src/vm/comcallablewrapper.cpp +++ b/src/vm/comcallablewrapper.cpp @@ -159,6 +159,9 @@ void DestructComCallMethodDescs(ArrayList *pDescArray) typedef Wrapper, DestructComCallMethodDescs> ComCallMethodDescArrayHolder; +// Forward declarations +static bool GetComIPFromCCW_HandleCustomQI(ComCallWrapper *pWrap, REFIID riid, MethodTable * pIntfMT, IUnknown **ppUnkOut); + //-------------------------------------------------------------------------- // IsDuplicateClassItfMD(MethodDesc *pMD, unsigned int ix) // Determines if the specified method desc is a duplicate. @@ -1277,7 +1280,7 @@ BOOL SimpleComCallWrapper::CustomQIRespondsToIMarshal() DWORD newFlags = enum_CustomQIRespondsToIMarshal_Inited; SafeComHolder pUnk; - if (ComCallWrapper::GetComIPFromCCW_HandleCustomQI(GetMainWrapper(), IID_IMarshal, NULL, &pUnk)) + if (GetComIPFromCCW_HandleCustomQI(GetMainWrapper(), IID_IMarshal, NULL, &pUnk)) { newFlags |= enum_CustomQIRespondsToIMarshal; } @@ -2314,6 +2317,37 @@ ComCallWrapper* ComCallWrapper::CopyFromTemplate(ComCallWrapperTemplate* pTempla RETURN pStartWrapper; } +//-------------------------------------------------------------------------- +// identify the location within the wrapper where the vtable for this index will +// be stored +//-------------------------------------------------------------------------- +SLOT** ComCallWrapper::GetComIPLocInWrapper(ComCallWrapper* pWrap, unsigned int iIndex) +{ + CONTRACT (SLOT**) + { + NOTHROW; + GC_NOTRIGGER; + MODE_ANY; + PRECONDITION(CheckPointer(pWrap)); + PRECONDITION(iIndex > 1); // We should never attempt to get the basic or IClassX interface here. + POSTCONDITION(CheckPointer(RETVAL)); + } + CONTRACT_END; + + SLOT** pTearOff = NULL; + while (iIndex >= NumVtablePtrs) + { + //@todo delayed creation support + _ASSERTE(pWrap->IsLinked() != 0); + pWrap = GetNext(pWrap); + iIndex-= NumVtablePtrs; + } + _ASSERTE(pWrap != NULL); + pTearOff = (SLOT **)&pWrap->m_rgpIPtr[iIndex]; + + RETURN pTearOff; +} + //-------------------------------------------------------------------------- // void ComCallWrapper::Cleanup(ComCallWrapper* pWrap) // clean up , release gc registered reference and free wrapper @@ -2471,7 +2505,6 @@ void ComCallWrapper::Neuter() } } - //-------------------------------------------------------------------------- // void ComCallWrapper::ClearHandle() // clear the ref-counted handle @@ -2487,6 +2520,20 @@ void ComCallWrapper::ClearHandle() } } +SLOT** ComCallWrapper::GetFirstInterfaceSlot() +{ + CONTRACT(SLOT**) + { + NOTHROW; + GC_NOTRIGGER; + MODE_ANY; + POSTCONDITION(CheckPointer(RETVAL)); + } + CONTRACT_END; + + SLOT** firstInterface = GetComIPLocInWrapper(this, Slot_FirstInterface); + RETURN firstInterface; +} //-------------------------------------------------------------------------- // void ComCallWrapper::FreeWrapper(ComCallWrapper* pWrap) @@ -2658,67 +2705,6 @@ ComCallWrapper* ComCallWrapper::CreateWrapper(OBJECTREF* ppObj, ComCallWrapperTe RETURN pStartWrapper; } - -//-------------------------------------------------------------------------- -// signed ComCallWrapper::GetIndexForIntfMT(ComCallWrapperTemplate *pTemplate, MethodTable *pIntfMT) -// check if the interface is supported, return a index into the IMap -// returns -1, if pIntfMT is not supported -//-------------------------------------------------------------------------- -signed ComCallWrapper::GetIndexForIntfMT(ComCallWrapperTemplate *pTemplate, MethodTable *pIntfMT) -{ - CONTRACTL - { - THROWS; - GC_TRIGGERS; - MODE_ANY; - PRECONDITION(CheckPointer(pTemplate)); - PRECONDITION(CheckPointer(pIntfMT)); - } - CONTRACTL_END; - - for (unsigned j = 0; j < pTemplate->GetNumInterfaces(); j++) - { - ComMethodTable *pItfComMT = (ComMethodTable *)pTemplate->GetVTableSlot(j) - 1; - if (pItfComMT->GetMethodTable()->IsEquivalentTo(pIntfMT)) - return j; - } - - // oops, iface not found - return -1; -} - -//-------------------------------------------------------------------------- -// SLOT** ComCallWrapper::GetComIPLocInWrapper(ComCallWrapper* pWrap, unsigned iIndex) -// identify the location within the wrapper where the vtable for this index will -// be stored -//-------------------------------------------------------------------------- -SLOT** ComCallWrapper::GetComIPLocInWrapper(ComCallWrapper* pWrap, unsigned iIndex) -{ - CONTRACT (SLOT**) - { - NOTHROW; - GC_NOTRIGGER; - MODE_ANY; - PRECONDITION(CheckPointer(pWrap)); - PRECONDITION(iIndex > 1); // We should never attempt to get the basic or IClassX interface here. - POSTCONDITION(CheckPointer(RETVAL)); - } - CONTRACT_END; - - SLOT** pTearOff = NULL; - while (iIndex >= NumVtablePtrs) - { - //@todo delayed creation support - _ASSERTE(pWrap->IsLinked() != 0); - pWrap = GetNext(pWrap); - iIndex-= NumVtablePtrs; - } - _ASSERTE(pWrap != NULL); - pTearOff = (SLOT **)&pWrap->m_rgpIPtr[iIndex]; - - RETURN pTearOff; -} - //-------------------------------------------------------------------------- // Get IClassX interface pointer from the wrapper. This method will also // lay out the IClassX COM method table if it has not yet been laid out. @@ -2886,8 +2872,7 @@ VOID __stdcall InvokeICustomQueryInterfaceGetInterface_CallBack(LPVOID ptr) } // Returns a covariant supertype of pMT with the given IID or NULL if not found. -// static -MethodTable *ComCallWrapper::FindCovariantSubtype(MethodTable *pMT, REFIID riid) +static MethodTable *FindCovariantSubtype(MethodTable *pMT, REFIID riid) { CONTRACTL { @@ -2958,21 +2943,121 @@ MethodTable *ComCallWrapper::FindCovariantSubtype(MethodTable *pMT, REFIID riid) return NULL; } -// Like GetComIPFromCCW, but will try to find riid/pIntfMT among interfaces implemented by this object that have variance. -IUnknown* ComCallWrapper::GetComIPFromCCWUsingVariance(REFIID riid, MethodTable* pIntfMT, GetComIPFromCCW::flags flags) +//-------------------------------------------------------------------------- +// check if the interface is supported, return a index into the IMap +// returns -1, if pIntfMT is not supported +//-------------------------------------------------------------------------- +static int GetIndexForIntfMT(ComCallWrapperTemplate *pTemplate, MethodTable *pIntfMT) { CONTRACTL { THROWS; GC_TRIGGERS; MODE_ANY; - PRECONDITION(GetComCallWrapperTemplate()->SupportsVariantInterface()); - PRECONDITION(!GetComCallWrapperTemplate()->RepresentsVariantInterface()); + PRECONDITION(CheckPointer(pTemplate)); + PRECONDITION(CheckPointer(pIntfMT)); + } + CONTRACTL_END; + + for (ULONG j = 0; j < pTemplate->GetNumInterfaces(); j++) + { + ComMethodTable *pItfComMT = (ComMethodTable *)pTemplate->GetVTableSlot(j) - 1; + if (pItfComMT->GetMethodTable()->IsEquivalentTo(pIntfMT)) + return j; + } + + return -1; +} + +static IUnknown *GetComIPFromCCW_VisibilityCheck( + IUnknown *pIntf, + MethodTable *pIntfMT, + ComMethodTable *pIntfComMT, + GetComIPFromCCW::flags flags) +{ + CONTRACT(IUnknown*) + { + THROWS; + GC_TRIGGERS; + MODE_ANY; + PRECONDITION(CheckPointer(pIntf)); + PRECONDITION(CheckPointer(pIntfComMT)); + } + CONTRACT_END; + + if (// Do a visibility check if needed. + ((flags & GetComIPFromCCW::CheckVisibility) && (!pIntfComMT->IsComVisible()))) + { + // If not, fail to return the interface. + SafeRelease(pIntf); + RETURN NULL; + } + RETURN pIntf; +} + +static IUnknown *GetComIPFromCCW_VariantInterface( + ComCallWrapper *pWrap, + REFIID riid, + MethodTable *pIntfMT, + GetComIPFromCCW::flags flags, + ComCallWrapperTemplate *pTemplate) +{ + CONTRACT(IUnknown*) + { + THROWS; + GC_TRIGGERS; + MODE_ANY; + PRECONDITION(CheckPointer(pWrap)); + } + CONTRACT_END; + + IUnknown* pIntf = NULL; + ComMethodTable *pIntfComMT = NULL; + + // we are only going to respond to the one interface that this CCW represents + pIntfComMT = pTemplate->GetComMTForIndex(0); + if (pIntfComMT->GetIID() == riid || (pIntfMT != NULL && GetIndexForIntfMT(pTemplate, pIntfMT) == 0)) + { + SLOT **ppVtable = pWrap->GetFirstInterfaceSlot(); + _ASSERTE(*ppVtable != NULL); // this should point to COM Vtable or interface vtable + + if (!pIntfComMT->IsLayoutComplete() && !pIntfComMT->LayOutInterfaceMethodTable(NULL)) + { + RETURN NULL; + } + + // The interface pointer is the pointer to the vtable. + pIntf = (IUnknown*)ppVtable; + + // AddRef the wrapper. + // Note that we don't do SafeAddRef(pIntf) because it's overkill to + // go via IUnknown when we already have the wrapper in-hand. + pWrap->AddRefWithAggregationCheck(); + + RETURN GetComIPFromCCW_VisibilityCheck(pIntf, pIntfMT, pIntfComMT, flags); + } + + // for anything else, fall back to the CCW representing the "parent" class + RETURN ComCallWrapper::GetComIPFromCCW(pWrap->GetSimpleWrapper()->GetClassWrapper(), riid, pIntfMT, flags); +} + +// Like GetComIPFromCCW, but will try to find riid/pIntfMT among interfaces implemented by this +// object that have variance. Assumes that call GetComIPFromCCW with same arguments has failed. +static IUnknown* GetComIPFromCCW_UsingVariance(ComCallWrapper *pWrap, REFIID riid, MethodTable* pIntfMT, GetComIPFromCCW::flags flags) +{ + CONTRACTL + { + THROWS; + GC_TRIGGERS; + MODE_ANY; + PRECONDITION(CheckPointer(pWrap)); + PRECONDITION(pWrap->GetComCallWrapperTemplate()->SupportsVariantInterface()); + PRECONDITION(!pWrap->GetComCallWrapperTemplate()->RepresentsVariantInterface()); } CONTRACTL_END; // try the fast per-ComCallWrapperTemplate cache first - ComCallWrapperTemplate::IIDToInterfaceTemplateCache *pCache = GetComCallWrapperTemplate()->GetOrCreateIIDToInterfaceTemplateCache(); + ComCallWrapperTemplate::IIDToInterfaceTemplateCache *pCache = pWrap->GetComCallWrapperTemplate()->GetOrCreateIIDToInterfaceTemplateCache(); GUID local_iid; const IID *piid = &riid; @@ -3027,7 +3112,7 @@ IUnknown* ComCallWrapper::GetComIPFromCCWUsingVariance(REFIID riid, MethodTable* // We'll perform a simplified check which is limited only to covariance with one generic parameter (luckily // all WinRT variant types currently fall into this bucket). // - TypeHandle thClass = GetComCallWrapperTemplate()->GetClassType(); + TypeHandle thClass = pWrap->GetComCallWrapperTemplate()->GetClassType(); ComCallWrapperTemplate::CCWInterfaceMapIterator it(thClass, NULL, false); while (it.Next()) @@ -3065,7 +3150,7 @@ IUnknown* ComCallWrapper::GetComIPFromCCWUsingVariance(REFIID riid, MethodTable* pVariantIntfMT = pIntfMT; } - TypeHandle thClass = GetComCallWrapperTemplate()->GetClassType(); + TypeHandle thClass = pWrap->GetComCallWrapperTemplate()->GetClassType(); if (pVariantIntfMT != NULL && thClass.CanCastTo(pVariantIntfMT)) { _ASSERTE_MSG(!thClass.GetMethodTable()->ImplementsInterface(pVariantIntfMT), "This should have been taken care of by GetComIPFromCCW"); @@ -3095,8 +3180,8 @@ IUnknown* ComCallWrapper::GetComIPFromCCWUsingVariance(REFIID riid, MethodTable* GCPROTECT_BEGIN(oref); { - oref = GetObjectRef(); - pCCW = InlineGetWrapper(&oref, pIntfTemplate, this); + oref = pWrap->GetObjectRef(); + pCCW = ComCallWrapper::InlineGetWrapper(&oref, pIntfTemplate, pWrap); } GCPROTECT_END(); } @@ -3109,35 +3194,130 @@ IUnknown* ComCallWrapper::GetComIPFromCCWUsingVariance(REFIID riid, MethodTable* return NULL; } -// static -inline IUnknown * ComCallWrapper::GetComIPFromCCW_VisibilityCheck( - IUnknown * pIntf, MethodTable * pIntfMT, ComMethodTable * pIntfComMT, - GetComIPFromCCW::flags flags) +static IUnknown * GetComIPFromCCW_HandleExtendsCOMObject( + ComCallWrapper * pWrap, + REFIID riid, + MethodTable * pIntfMT, + ComCallWrapperTemplate * pTemplate, + int imapIndex, + unsigned int intfIndex) +{ + CONTRACTL + { + THROWS; + GC_TRIGGERS; + MODE_ANY; + } + CONTRACTL_END; + + // If we don't implement the interface, we delegate to base + BOOL bDelegateToBase = TRUE; + if (imapIndex != -1) + { + MethodTable * pMT = pWrap->GetMethodTableOfObjectRef(); + + // Check if this index is actually an interface implemented by us + // if it belongs to the base COM guy then we can hand over the call + // to him + if (pMT->IsWinRTObjectType()) + { + bDelegateToBase = pTemplate->GetComMTForIndex(intfIndex)->IsWinRTTrivialAggregate(); + } + else + { + MethodTable::InterfaceMapIterator intIt = pMT->IterateInterfaceMapFrom(intfIndex); + + // If the number of slots is 0, then no need to proceed + if (intIt.GetInterface()->GetNumVirtuals() != 0) + { + MethodDesc *pClsMD = NULL; + + // Find the implementation for the first slot of the interface + DispatchSlot impl(pMT->FindDispatchSlot(intIt.GetInterface()->GetTypeID(), 0, FALSE /* throwOnConflict */)); + CONSISTENCY_CHECK(!impl.IsNull()); + + // Get the MethodDesc for this slot in the class + pClsMD = impl.GetMethodDesc(); + + MethodTable * pClsMT = pClsMD->GetMethodTable(); + bDelegateToBase = (pClsMT->IsInterface() || pClsMT->IsComImport()) ? TRUE : FALSE; + } + else + { + // The interface has no methods so we cannot override it. Because of this + // it makes sense to delegate to the base COM component. + bDelegateToBase = TRUE; + } + } + } + + if (bDelegateToBase) + { + // This is an interface of the base COM guy so delegate the call to him + SyncBlock* pBlock = pWrap->GetSyncBlock(); + _ASSERTE(pBlock); + + SafeComHolder pUnk; + + RCWHolder pRCW(GetThread()); + RCWPROTECT_BEGIN(pRCW, pBlock); + + pUnk = (pIntfMT != NULL) ? pRCW->GetComIPFromRCW(pIntfMT) + : pRCW->GetComIPFromRCW(riid); + + RCWPROTECT_END(pRCW); + return pUnk.Extract(); + } + + return NULL; +} + +static IUnknown * GetComIPFromCCW_ForIID_Worker( + ComCallWrapper *pWrap, + REFIID riid, + MethodTable *pIntfMT, + GetComIPFromCCW::flags flags, + ComCallWrapperTemplate * pTemplate) { CONTRACT(IUnknown*) { THROWS; GC_TRIGGERS; MODE_ANY; - PRECONDITION(CheckPointer(pIntf)); - PRECONDITION(CheckPointer(pIntfComMT)); + PRECONDITION(CheckPointer(pWrap)); + POSTCONDITION(CheckPointer(RETVAL, NULL_OK)); } CONTRACT_END; - if (// Do a visibility check if needed. - ((flags & GetComIPFromCCW::CheckVisibility) && (!pIntfComMT->IsComVisible()))) + ComMethodTable * pIntfComMT = NULL; + MethodTable * pMT = pWrap->GetMethodTableOfObjectRef(); + + // At this point, it must be that the IID is one of IClassX IIDs or + // it isn't implemented on this class. We'll have to search through and set + // up the entire hierarchy to determine which it is. + if (IsIClassX(pMT, riid, &pIntfComMT)) { - // If not, fail to return the interface. - SafeRelease(pIntf); - RETURN NULL; + // If the class that this IClassX's was generated for is marked + // as ClassInterfaceType.AutoDual or AutoDisp, or it is a WinRT + // delegate, then give out the IClassX IP. + if (pIntfComMT->GetClassInterfaceType() == clsIfAutoDual || pIntfComMT->GetClassInterfaceType() == clsIfAutoDisp || + pIntfComMT->IsWinRTDelegate()) + { + // Make sure the all the base classes of the class this IClassX corresponds to + // are visible to COM. + pIntfComMT->CheckParentComVisibility(FALSE); + + // Giveout IClassX of this class because the IID matches one of the IClassX in the hierarchy + // This assumes any IClassX implementation must be derived from base class IClassX's implementation + IUnknown * pIntf = pWrap->GetIClassXIP(); + RETURN GetComIPFromCCW_VisibilityCheck(pIntf, pIntfMT, pIntfComMT, flags); + } } - RETURN pIntf; + + RETURN NULL; } -// static -IUnknown * ComCallWrapper::GetComIPFromCCW_VariantInterface( - ComCallWrapper * pWrap, REFIID riid, MethodTable * pIntfMT, - GetComIPFromCCW::flags flags, ComCallWrapperTemplate * pTemplate) +static IUnknown *GetComIPFromCCW_ForIntfMT_Worker(ComCallWrapper *pWrap, MethodTable *pIntfMT, GetComIPFromCCW::flags flags) { CONTRACT(IUnknown*) { @@ -3145,41 +3325,47 @@ IUnknown * ComCallWrapper::GetComIPFromCCW_VariantInterface( GC_TRIGGERS; MODE_ANY; PRECONDITION(CheckPointer(pWrap)); + POSTCONDITION(CheckPointer(RETVAL, NULL_OK)); } CONTRACT_END; - IUnknown* pIntf = NULL; - ComMethodTable *pIntfComMT = NULL; + MethodTable * pMT = pWrap->GetMethodTableOfObjectRef(); - // we are only going to respond to the one interface that this CCW represents - pIntfComMT = pTemplate->GetComMTForIndex(0); - if (pIntfComMT->GetIID() == riid || (pIntfMT != NULL && GetIndexForIntfMT(pTemplate, pIntfMT) == 0)) + // class method table + if (pMT->CanCastToClass(pIntfMT)) { - SLOT **ppVtable = GetComIPLocInWrapper(pWrap, Slot_FirstInterface); - _ASSERTE(*ppVtable != NULL); // this should point to COM Vtable or interface vtable - - if (!pIntfComMT->IsLayoutComplete() && !pIntfComMT->LayOutInterfaceMethodTable(NULL)) + // Make sure we're not trying to pass out a generic-based class interface (except for WinRT delegates) + if (pMT->HasInstantiation() && !pMT->SupportsGenericInterop(TypeHandle::Interop_NativeToManaged)) { - RETURN NULL; + COMPlusThrow(kInvalidOperationException, IDS_EE_ATTEMPT_TO_CREATE_GENERIC_CCW); } - // The interface pointer is the pointer to the vtable. - pIntf = (IUnknown*)ppVtable; + // Retrieve the COM method table for the requested interface. + ComCallWrapperTemplate *pIntfCCWTemplate = ComCallWrapperTemplate::GetTemplate(TypeHandle(pIntfMT)); + if (pIntfCCWTemplate->SupportsIClassX()) + { + ComMethodTable * pIntfComMT = pIntfCCWTemplate->GetClassComMT(); - // AddRef the wrapper. - // Note that we don't do SafeAddRef(pIntf) because it's overkill to - // go via IUnknown when we already have the wrapper in-hand. - pWrap->AddRefWithAggregationCheck(); + // If the class that this IClassX's was generated for is marked + // as ClassInterfaceType.AutoDual or AutoDisp, or it is a WinRT + // delegate, then give out the IClassX IP. + if (pIntfComMT->GetClassInterfaceType() == clsIfAutoDual || pIntfComMT->GetClassInterfaceType() == clsIfAutoDisp || + pIntfComMT->IsWinRTDelegate()) + { + // Make sure the all the base classes of the class this IClassX corresponds to + // are visible to COM. + pIntfComMT->CheckParentComVisibility(FALSE); - RETURN GetComIPFromCCW_VisibilityCheck(pIntf, pIntfMT, pIntfComMT, flags); + // Giveout IClassX + IUnknown * pIntf = pWrap->GetIClassXIP(); + RETURN GetComIPFromCCW_VisibilityCheck(pIntf, pIntfMT, pIntfComMT, flags); + } + } } - - // for anything else, fall back to the CCW representing the "parent" class - RETURN GetComIPFromCCW(pWrap->GetSimpleWrapper()->GetClassWrapper(), riid, pIntfMT, flags); + RETURN NULL; } -// static -bool ComCallWrapper::GetComIPFromCCW_HandleCustomQI( +static bool GetComIPFromCCW_HandleCustomQI( ComCallWrapper * pWrap, REFIID riid, MethodTable * pIntfMT, IUnknown ** ppUnkOut) { CONTRACTL @@ -3243,86 +3429,6 @@ MethodTable * ComCallWrapper::GetMethodTableOfObjectRef() return GetObjectRef()->GetMethodTable(); } -// static -IUnknown * ComCallWrapper::GetComIPFromCCW_HandleExtendsCOMObject( - ComCallWrapper * pWrap, REFIID riid, MethodTable * pIntfMT, - ComCallWrapperTemplate * pTemplate, signed imapIndex, unsigned intfIndex) -{ - CONTRACTL - { - THROWS; - GC_TRIGGERS; - MODE_ANY; - } - CONTRACTL_END; - - BOOL bDelegateToBase = FALSE; - if (imapIndex != -1) - { - MethodTable * pMT = pWrap->GetMethodTableOfObjectRef(); - - // Check if this index is actually an interface implemented by us - // if it belongs to the base COM guy then we can hand over the call - // to him - if (pMT->IsWinRTObjectType()) - { - bDelegateToBase = pTemplate->GetComMTForIndex(intfIndex)->IsWinRTTrivialAggregate(); - } - else - { - MethodTable::InterfaceMapIterator intIt = pMT->IterateInterfaceMapFrom(intfIndex); - - // If the number of slots is 0, then no need to proceed - if (intIt.GetInterface()->GetNumVirtuals() != 0) - { - MethodDesc *pClsMD = NULL; - - // Find the implementation for the first slot of the interface - DispatchSlot impl(pMT->FindDispatchSlot(intIt.GetInterface()->GetTypeID(), 0, FALSE /* throwOnConflict */)); - CONSISTENCY_CHECK(!impl.IsNull()); - - // Get the MethodDesc for this slot in the class - pClsMD = impl.GetMethodDesc(); - - MethodTable * pClsMT = pClsMD->GetMethodTable(); - if (pClsMT->IsInterface() || pClsMT->IsComImport()) - bDelegateToBase = TRUE; - } - else - { - // The interface has no methods so we cannot override it. Because of this - // it makes sense to delegate to the base COM component. - bDelegateToBase = TRUE; - } - } - } - else - { - // If we don't implement the interface, we delegate to base - bDelegateToBase = TRUE; - } - - if (bDelegateToBase) - { - // This is an interface of the base COM guy so delegate the call to him - SyncBlock* pBlock = pWrap->GetSyncBlock(); - _ASSERTE(pBlock); - - SafeComHolder pUnk; - - RCWHolder pRCW(GetThread()); - RCWPROTECT_BEGIN(pRCW, pBlock); - - pUnk = (pIntfMT != NULL) ? pRCW->GetComIPFromRCW(pIntfMT) - : pRCW->GetComIPFromRCW(riid); - - RCWPROTECT_END(pRCW); - return pUnk.Extract(); - } - - return NULL; -} - //-------------------------------------------------------------------------- // IUnknown* ComCallWrapper::GetComIPfromCCW(ComCallWrapper *pWrap, REFIID riid, MethodTable* pIntfMT, BOOL bCheckVisibility) // Get an interface from wrapper, based on riid or pIntfMT. The returned interface is AddRef'd. @@ -3466,7 +3572,7 @@ IUnknown* ComCallWrapper::GetComIPFromCCW(ComCallWrapper *pWrap, REFIID riid, Me { // We haven't found an interface corresponding to the incoming pIntfMT/IID because we don't implement it. // However, we could still implement an interface that is castable to pIntfMT/IID via co-/contra-variance. - IUnknown * pIntf = pWrap->GetComIPFromCCWUsingVariance(riid, pIntfMT, flags); + IUnknown * pIntf = GetComIPFromCCW_UsingVariance(pWrap, riid, pIntfMT, flags); if (pIntf != NULL) { RETURN pIntf; @@ -3539,99 +3645,6 @@ IUnknown* ComCallWrapper::GetComIPFromCCW(ComCallWrapper *pWrap, REFIID riid, Me RETURN pIntf; } -// static -IUnknown * ComCallWrapper::GetComIPFromCCW_ForIID_Worker( - ComCallWrapper * pWrap, REFIID riid, MethodTable * pIntfMT, GetComIPFromCCW::flags flags, - ComCallWrapperTemplate * pTemplate) -{ - CONTRACT(IUnknown*) - { - THROWS; - GC_TRIGGERS; - MODE_ANY; - PRECONDITION(CheckPointer(pWrap)); - POSTCONDITION(CheckPointer(RETVAL, NULL_OK)); - } - CONTRACT_END; - - ComMethodTable * pIntfComMT = NULL; - MethodTable * pMT = pWrap->GetMethodTableOfObjectRef(); - - // At this point, it must be that the IID is one of IClassX IIDs or - // it isn't implemented on this class. We'll have to search through and set - // up the entire hierarchy to determine which it is. - if (IsIClassX(pMT, riid, &pIntfComMT)) - { - // If the class that this IClassX's was generated for is marked - // as ClassInterfaceType.AutoDual or AutoDisp, or it is a WinRT - // delegate, then give out the IClassX IP. - if (pIntfComMT->GetClassInterfaceType() == clsIfAutoDual || pIntfComMT->GetClassInterfaceType() == clsIfAutoDisp || - pIntfComMT->IsWinRTDelegate()) - { - // Make sure the all the base classes of the class this IClassX corresponds to - // are visible to COM. - pIntfComMT->CheckParentComVisibility(FALSE); - - // Giveout IClassX of this class because the IID matches one of the IClassX in the hierarchy - // This assumes any IClassX implementation must be derived from base class IClassX's implementation - IUnknown * pIntf = pWrap->GetIClassXIP(); - RETURN GetComIPFromCCW_VisibilityCheck(pIntf, pIntfMT, pIntfComMT, flags); - } - } - - RETURN NULL; -} - -// static -IUnknown * ComCallWrapper::GetComIPFromCCW_ForIntfMT_Worker( - ComCallWrapper * pWrap, MethodTable * pIntfMT, GetComIPFromCCW::flags flags) -{ - CONTRACT(IUnknown*) - { - THROWS; - GC_TRIGGERS; - MODE_ANY; - PRECONDITION(CheckPointer(pWrap)); - POSTCONDITION(CheckPointer(RETVAL, NULL_OK)); - } - CONTRACT_END; - - MethodTable * pMT = pWrap->GetMethodTableOfObjectRef(); - - // class method table - if (pMT->CanCastToClass(pIntfMT)) - { - // Make sure we're not trying to pass out a generic-based class interface (except for WinRT delegates) - if (pMT->HasInstantiation() && !pMT->SupportsGenericInterop(TypeHandle::Interop_NativeToManaged)) - { - COMPlusThrow(kInvalidOperationException, IDS_EE_ATTEMPT_TO_CREATE_GENERIC_CCW); - } - - // Retrieve the COM method table for the requested interface. - ComCallWrapperTemplate *pIntfCCWTemplate = ComCallWrapperTemplate::GetTemplate(TypeHandle(pIntfMT)); - if (pIntfCCWTemplate->SupportsIClassX()) - { - ComMethodTable * pIntfComMT = pIntfCCWTemplate->GetClassComMT(); - - // If the class that this IClassX's was generated for is marked - // as ClassInterfaceType.AutoDual or AutoDisp, or it is a WinRT - // delegate, then give out the IClassX IP. - if (pIntfComMT->GetClassInterfaceType() == clsIfAutoDual || pIntfComMT->GetClassInterfaceType() == clsIfAutoDisp || - pIntfComMT->IsWinRTDelegate()) - { - // Make sure the all the base classes of the class this IClassX corresponds to - // are visible to COM. - pIntfComMT->CheckParentComVisibility(FALSE); - - // Giveout IClassX - IUnknown * pIntf = pWrap->GetIClassXIP(); - RETURN GetComIPFromCCW_VisibilityCheck(pIntf, pIntfMT, pIntfComMT, flags); - } - } - } - RETURN NULL; -} - //-------------------------------------------------------------------------- // Get the IDispatch interface pointer for the wrapper. // The returned interface is AddRef'd. diff --git a/src/vm/comcallablewrapper.h b/src/vm/comcallablewrapper.h index 5773a9d8c8f8..104d7ec6fa2d 100644 --- a/src/vm/comcallablewrapper.h +++ b/src/vm/comcallablewrapper.h @@ -978,24 +978,21 @@ class ComCallWrapper { friend class MarshalNative; friend class ClrDataAccess; - + private: enum { -#ifdef _WIN64 NumVtablePtrs = 5, - enum_ThisMask = ~0x3f, // mask on IUnknown ** to get at the OBJECT-REF handle +#ifdef _WIN64 + enum_ThisMask = ~0x3f, // mask on IUnknown ** to get at the OBJECT-REF handle #else - - NumVtablePtrs = 5, enum_ThisMask = ~0x1f, // mask on IUnknown ** to get at the OBJECT-REF handle #endif - Slot_IClassX = 1, - Slot_Basic = 0, - + Slot_Basic = 0, + Slot_IClassX = 1, Slot_FirstInterface = 2, }; - + public: BOOL IsHandleWeak(); VOID MarkHandleWeak(); @@ -1012,7 +1009,7 @@ class ComCallWrapper protected: #ifndef DACCESS_COMPILE - inline static void SetNext(ComCallWrapper* pWrap, ComCallWrapper* pNextWrapper) + static void SetNext(ComCallWrapper* pWrap, ComCallWrapper* pNextWrapper) { CONTRACTL { @@ -1028,7 +1025,7 @@ class ComCallWrapper } #endif // !DACCESS_COMPILE - inline static PTR_ComCallWrapper GetNext(PTR_ComCallWrapper pWrap) + static PTR_ComCallWrapper GetNext(PTR_ComCallWrapper pWrap) { CONTRACT (PTR_ComCallWrapper) { @@ -1040,20 +1037,13 @@ class ComCallWrapper POSTCONDITION(CheckPointer(RETVAL, NULL_OK)); } CONTRACT_END; - + RETURN (LinkedWrapperTerminator == pWrap->m_pNext ? NULL : pWrap->m_pNext); } // Helper to create a wrapper, pClassCCW must be specified if pTemplate->RepresentsVariantInterface() static ComCallWrapper* CreateWrapper(OBJECTREF* pObj, ComCallWrapperTemplate *pTemplate, ComCallWrapper *pClassCCW); - // helper to get the IUnknown* within a wrapper - static SLOT** GetComIPLocInWrapper(ComCallWrapper* pWrap, unsigned iIndex); - - // helper to get index within the interface map for an interface that matches - // the interface MT - static signed GetIndexForIntfMT(ComCallWrapperTemplate *pTemplate, MethodTable *pIntfMT); - // helper to get wrapper from sync block static PTR_ComCallWrapper GetStartWrapper(PTR_ComCallWrapper pWrap); @@ -1062,36 +1052,10 @@ class ComCallWrapper ComCallWrapperCache *pWrapperCache, OBJECTHANDLE oh); - // helper to find a covariant supertype of pMT with the given IID - static MethodTable *FindCovariantSubtype(MethodTable *pMT, REFIID riid); - - // Like GetComIPFromCCW, but will try to find riid/pIntfMT among interfaces implemented by this - // object that have variance. Assumes that call GetComIPFromCCW with same arguments has failed. - IUnknown *GetComIPFromCCWUsingVariance(REFIID riid, MethodTable *pIntfMT, GetComIPFromCCW::flags flags); - - static IUnknown * GetComIPFromCCW_VariantInterface( - ComCallWrapper * pWrap, REFIID riid, MethodTable * pIntfMT, GetComIPFromCCW::flags flags, - ComCallWrapperTemplate * pTemplate); - - inline static IUnknown * GetComIPFromCCW_VisibilityCheck( - IUnknown * pIntf, MethodTable * pIntfMT, ComMethodTable * pIntfComMT, - GetComIPFromCCW::flags flags); - - static IUnknown * GetComIPFromCCW_HandleExtendsCOMObject( - ComCallWrapper * pWrap, REFIID riid, MethodTable * pIntfMT, - ComCallWrapperTemplate * pTemplate, signed imapIndex, unsigned intfIndex); - - static IUnknown * GetComIPFromCCW_ForIID_Worker( - ComCallWrapper * pWrap, REFIID riid, MethodTable * pIntfMT, GetComIPFromCCW::flags flags, - ComCallWrapperTemplate * pTemplate); - - static IUnknown * GetComIPFromCCW_ForIntfMT_Worker( - ComCallWrapper * pWrap, MethodTable * pIntfMT, GetComIPFromCCW::flags flags); - + static SLOT** GetComIPLocInWrapper(ComCallWrapper* pWrap, unsigned int iIndex); public: - static bool GetComIPFromCCW_HandleCustomQI( - ComCallWrapper * pWrap, REFIID riid, MethodTable * pIntfMT, IUnknown ** ppUnkOut); + SLOT** GetFirstInterfaceSlot(); // walk the list and free all blocks void FreeWrapper(ComCallWrapperCache *pWrapperCache); @@ -1113,7 +1077,6 @@ class ComCallWrapper return m_pNext != NULL; } - // wrapper is not guaranteed to be present // accessor to wrapper object in the sync block inline static PTR_ComCallWrapper GetWrapperForObject(OBJECTREF pObj, ComCallWrapperTemplate *pTemplate = NULL) diff --git a/tests/src/Interop/COM/NETServer/NETServer.DefaultInterfaces.il b/tests/src/Interop/COM/NETServer/NETServer.DefaultInterfaces.il index 3e4180c7b7cb..a45fc521069d 100644 --- a/tests/src/Interop/COM/NETServer/NETServer.DefaultInterfaces.il +++ b/tests/src/Interop/COM/NETServer/NETServer.DefaultInterfaces.il @@ -62,16 +62,17 @@ ret } - .method public abstract virtual - instance int32 DefOnInterface2Ret5() cil managed + .method public virtual + instance int32 DefOnInterfaceRet5() cil managed { - // Simple interface method + // Default interface implementation + ldc.i4 5 + ret } } // Interface defintion for overriding another interface's implementation -.class interface private abstract auto ansi Server.Contract.IDefaultInterfaceTesting2 - implements Server.Contract.IDefaultInterfaceTesting +.class interface public abstract auto ansi Server.Contract.IDefaultInterfaceTesting2 { .custom instance void [System.Runtime]System.Runtime.InteropServices.ComVisibleAttribute::.ctor(bool) = ( 01 00 01 00 00 ) @@ -82,20 +83,13 @@ 39 34 42 46 37 41 30 00 00 ) .custom instance void [System.Runtime.InteropServices]System.Runtime.InteropServices.InterfaceTypeAttribute::.ctor(valuetype [System.Runtime.InteropServices]System.Runtime.InteropServices.ComInterfaceType) = ( 01 00 01 00 00 00 00 00 ) - - .method public virtual final - instance int32 DefOnInterface2Ret5() cil managed - { - .override Server.Contract.IDefaultInterfaceTesting::DefOnInterface2Ret5 - ldc.i4 5 - ret - } } // COM server consuming interfaces with default methods .class public auto ansi beforefieldinit DefaultInterfaceTesting extends [System.Runtime]System.Object - implements Server.Contract.IDefaultInterfaceTesting2 + implements Server.Contract.IDefaultInterfaceTesting, + Server.Contract.IDefaultInterfaceTesting2 { .custom instance void [System.Runtime]System.Runtime.InteropServices.ComVisibleAttribute::.ctor(bool) = ( 01 00 01 00 00 ) // GUID: FAEF42AE-C1A4-419F-A912-B768AC2679EA diff --git a/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp b/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp index 357bea6408e2..1751abbbc5a8 100644 --- a/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp +++ b/tests/src/Interop/COM/NativeClients/DefaultInterfaces/DefaultInterfacesTests.cpp @@ -34,7 +34,9 @@ struct ComInit using ComMTA = ComInit; -void CallDefaultInterface(); +void ActivateClassWithDefaultInterfaces(); +void FailToActivateDefaultInterfaceInstance(); +void FailToQueryInterfaceForDefaultInterface(); int __cdecl main() { @@ -46,7 +48,9 @@ int __cdecl main() { CoreShimComActivation csact{ W("NetServer.DefaultInterfaces"), W("DefaultInterfaceTesting") }; - CallDefaultInterface(); + ActivateClassWithDefaultInterfaces(); + FailToActivateDefaultInterfaceInstance(); + FailToQueryInterfaceForDefaultInterface(); } catch (HRESULT hr) { @@ -57,15 +61,52 @@ int __cdecl main() return 100; } -void CallDefaultInterface() +void ActivateClassWithDefaultInterfaces() { - ::printf("Call functions on Default Interface...\n"); + ::printf("Activate class using default interfaces via IUnknown...\n"); + + HRESULT hr; + + // Validate a class that has an interface with function definitions can be activated + { + ComSmartPtr unknown; + THROW_IF_FAILED(::CoCreateInstance(CLSID_DefaultInterfaceTesting, nullptr, CLSCTX_INPROC, IID_IUnknown, (void**)&unknown)); + THROW_FAIL_IF_FALSE(unknown != nullptr); + } + + { + ComSmartPtr classFactory; + THROW_IF_FAILED(::CoGetClassObject(CLSID_DefaultInterfaceTesting, CLSCTX_INPROC, nullptr, IID_IClassFactory, (void**)&classFactory)); + + ComSmartPtr unknown; + THROW_IF_FAILED(classFactory->CreateInstance(nullptr, IID_IUnknown, (void**)&unknown)); + THROW_FAIL_IF_FALSE(unknown != nullptr); + } +} + +const int COR_E_INVALIDOPERATION = 0x80131509; + +void FailToActivateDefaultInterfaceInstance() +{ + ::printf("Fail to activate class via a default interface...\n"); HRESULT hr; ComSmartPtr defInterface; hr = ::CoCreateInstance(CLSID_DefaultInterfaceTesting, nullptr, CLSCTX_INPROC, IID_IDefaultInterfaceTesting, (void**)&defInterface); + THROW_FAIL_IF_FALSE(hr == COR_E_INVALIDOPERATION); +} + +void FailToQueryInterfaceForDefaultInterface() +{ + ::printf("Fail to QueryInterface() for default interface...\n"); + + HRESULT hr; - const int COR_E_INVALIDOPERATION = 0x80131509; + ComSmartPtr defInterface2; + THROW_IF_FAILED(::CoCreateInstance(CLSID_DefaultInterfaceTesting, nullptr, CLSCTX_INPROC, IID_IDefaultInterfaceTesting2, (void**)&defInterface2)); + + ComSmartPtr defInterface; + hr = defInterface2->QueryInterface(&defInterface); THROW_FAIL_IF_FALSE(hr == COR_E_INVALIDOPERATION); } diff --git a/tests/src/Interop/COM/ServerContracts/Server.Contracts.h b/tests/src/Interop/COM/ServerContracts/Server.Contracts.h index a3629c140cdc..ded28c215c9a 100644 --- a/tests/src/Interop/COM/ServerContracts/Server.Contracts.h +++ b/tests/src/Interop/COM/ServerContracts/Server.Contracts.h @@ -488,13 +488,13 @@ IDefaultInterfaceTesting : IUnknown virtual HRESULT STDMETHODCALLTYPE DefOnClassRet3(_Out_ int *p) = 0; - virtual HRESULT STDMETHODCALLTYPE DefOnInterface2Ret5(_Out_ int *p) = 0; + virtual HRESULT STDMETHODCALLTYPE DefOnInterfaceRet5(_Out_ int *p) = 0; }; struct __declspec(uuid("9B3CE792-F063-427D-B48E-4354094BF7A0")) -IDefaultInterfaceTesting2 : IDefaultInterfaceTesting +IDefaultInterfaceTesting2 : IUnknown { - + // Empty }; #pragma pack(pop) From da1e7b2d0f63e862cb8183e88dc45542a540ee2c Mon Sep 17 00:00:00 2001 From: Aaron Robinson Date: Wed, 17 Apr 2019 23:05:07 -0700 Subject: [PATCH 4/5] Address memory leak --- .../Runtime/InteropServices/ComActivator.cs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs index 2001e58d6114..3732f1886ed6 100644 --- a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs +++ b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs @@ -264,14 +264,24 @@ public static Type GetValidatedInterfaceType(Type classType, ref Guid riid, obje throw new InvalidCastException(); } - public static void ValidateInterfaceIsMarshallable(object obj, Type interfaceType) + public static void ValidateInterfaceIsMarshallable(object? obj, Type interfaceType) { Debug.Assert(obj != null && interfaceType != null); if (interfaceType != typeof(object)) { Debug.Assert(interfaceType.IsInterface); - Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore); + + // The intent of this call is to validate the interface can be + // marshalled to native code. An exception will be thrown if the + // type is unable to be marshalled to native code. + // Scenarios where this is relevant: + // - Interfaces that use Generics + // - Interfaces that define implementation + IntPtr ptr = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore); + + // Decrement the above 'Marshal.GetComInterfaceForObject()' + Marshal.ReleaseComObject(ptr); } } From e3fbf1780badef4807ce94238beb9ba5fe789440 Mon Sep 17 00:00:00 2001 From: Aaron Robinson Date: Thu, 18 Apr 2019 11:16:07 -0700 Subject: [PATCH 5/5] Fix release logic and cleanup merge with nullable work. --- .../Runtime/InteropServices/ComActivator.cs | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs index 3732f1886ed6..e080a31f1b38 100644 --- a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs +++ b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs @@ -264,25 +264,27 @@ public static Type GetValidatedInterfaceType(Type classType, ref Guid riid, obje throw new InvalidCastException(); } - public static void ValidateInterfaceIsMarshallable(object? obj, Type interfaceType) + public static void ValidateObjectIsMarshallableAsInterface(object obj, Type interfaceType) { - Debug.Assert(obj != null && interfaceType != null); - - if (interfaceType != typeof(object)) + // If the requested "interface type" is type object then return + // because type object is always marshallable. + if (interfaceType == typeof(object)) { - Debug.Assert(interfaceType.IsInterface); - - // The intent of this call is to validate the interface can be - // marshalled to native code. An exception will be thrown if the - // type is unable to be marshalled to native code. - // Scenarios where this is relevant: - // - Interfaces that use Generics - // - Interfaces that define implementation - IntPtr ptr = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore); - - // Decrement the above 'Marshal.GetComInterfaceForObject()' - Marshal.ReleaseComObject(ptr); + return; } + + Debug.Assert(interfaceType.IsInterface); + + // The intent of this call is to validate the interface can be + // marshalled to native code. An exception will be thrown if the + // type is unable to be marshalled to native code. + // Scenarios where this is relevant: + // - Interfaces that use Generics + // - Interfaces that define implementation + IntPtr ptr = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore); + + // Decrement the above 'Marshal.GetComInterfaceForObject()' + Marshal.Release(ptr); } public static object CreateAggregatedObject(object pUnkOuter, object comObject) @@ -298,7 +300,7 @@ public static object CreateAggregatedObject(object pUnkOuter, object comObject) finally { // Decrement the above 'Marshal.GetIUnknownForObject()' - Marshal.ReleaseComObject(pUnkOuter); + Marshal.Release(outerPtr); } } @@ -315,7 +317,7 @@ public void CreateInstance( ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject); } - BasicClassFactory.ValidateInterfaceIsMarshallable(ppvObject, interfaceType); + BasicClassFactory.ValidateObjectIsMarshallableAsInterface(ppvObject, interfaceType); } public void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock) @@ -394,7 +396,7 @@ private void CreateInstanceInner( ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject); } - BasicClassFactory.ValidateInterfaceIsMarshallable(ppvObject, interfaceType); + BasicClassFactory.ValidateObjectIsMarshallableAsInterface(ppvObject, interfaceType); } } } @@ -586,12 +588,12 @@ public void GetCurrentContextInfo(RuntimeTypeHandle rth, out bool isDesignTime, // Types are as follows: // Type, out bool, out string -> LicenseContext - var parameters = new object[] { targetRcwTypeMaybe, /* out */ null!, /* out */ null! }; + var parameters = new object?[] { targetRcwTypeMaybe, /* out */ null, /* out */ null }; _licContext = _getCurrentContextInfo.Invoke(null, BindingFlags.DoNotWrapExceptions, binder: null, parameters: parameters, culture: null); _targetRcwType = targetRcwTypeMaybe; - isDesignTime = (bool)parameters[1]; - bstrKey = Marshal.StringToBSTR((string?)parameters[2]); + isDesignTime = (bool)parameters[1]!; + bstrKey = Marshal.StringToBSTR((string)parameters[2]!); } // The CLR invokes this when instantiating a licensed COM