diff --git a/src/coreclr/vm/custommarshalerinfo.cpp b/src/coreclr/vm/custommarshalerinfo.cpp index 67acbff136ec36..0af48f8e715765 100644 --- a/src/coreclr/vm/custommarshalerinfo.cpp +++ b/src/coreclr/vm/custommarshalerinfo.cpp @@ -67,13 +67,6 @@ CustomMarshalerInfo::CustomMarshalerInfo(LoaderAllocator *pLoaderAllocator, Type STRINGREF CookieStringObj = StringObject::NewString(strCookie, cCookieStrBytes); GCPROTECT_BEGIN(CookieStringObj); #endif - - // Load the method desc's for all the methods in the ICustomMarshaler interface. - m_pMarshalNativeToManagedMD = GetCustomMarshalerMD(CustomMarshalerMethods_MarshalNativeToManaged, hndCustomMarshalerType); - m_pMarshalManagedToNativeMD = GetCustomMarshalerMD(CustomMarshalerMethods_MarshalManagedToNative, hndCustomMarshalerType); - m_pCleanUpNativeDataMD = GetCustomMarshalerMD(CustomMarshalerMethods_CleanUpNativeData, hndCustomMarshalerType); - m_pCleanUpManagedDataMD = GetCustomMarshalerMD(CustomMarshalerMethods_CleanUpManagedData, hndCustomMarshalerType); - // Load the method desc for the static method to retrieve the instance. MethodDesc *pGetCustomMarshalerMD = GetCustomMarshalerMD(CustomMarshalerMethods_GetInstance, hndCustomMarshalerType); @@ -103,7 +96,9 @@ CustomMarshalerInfo::CustomMarshalerInfo(LoaderAllocator *pLoaderAllocator, Type }; // Call the GetCustomMarshaler method to retrieve the custom marshaler to use. - OBJECTREF CustomMarshalerObj = getCustomMarshaler.Call_RetOBJECTREF(GetCustomMarshalerArgs); + OBJECTREF CustomMarshalerObj = NULL; + GCPROTECT_BEGIN(CustomMarshalerObj); + CustomMarshalerObj = getCustomMarshaler.Call_RetOBJECTREF(GetCustomMarshalerArgs); if (!CustomMarshalerObj) { DefineFullyQualifiedNameForClassW() @@ -111,7 +106,16 @@ CustomMarshalerInfo::CustomMarshalerInfo(LoaderAllocator *pLoaderAllocator, Type IDS_EE_NOCUSTOMMARSHALER, GetFullyQualifiedNameForClassW(hndCustomMarshalerType.GetMethodTable())); } + // Load the method desc's for all the methods in the ICustomMarshaler interface based on the type of the marshaler object. + TypeHandle customMarshalerObjType = CustomMarshalerObj->GetMethodTable(); + + m_pMarshalNativeToManagedMD = GetCustomMarshalerMD(CustomMarshalerMethods_MarshalNativeToManaged, customMarshalerObjType); + m_pMarshalManagedToNativeMD = GetCustomMarshalerMD(CustomMarshalerMethods_MarshalManagedToNative, customMarshalerObjType); + m_pCleanUpNativeDataMD = GetCustomMarshalerMD(CustomMarshalerMethods_CleanUpNativeData, customMarshalerObjType); + m_pCleanUpManagedDataMD = GetCustomMarshalerMD(CustomMarshalerMethods_CleanUpManagedData, customMarshalerObjType); + m_hndCustomMarshaler = pLoaderAllocator->AllocateHandle(CustomMarshalerObj); + GCPROTECT_END(); // Retrieve the size of the native data. if (m_bDataIsByValue) diff --git a/src/tests/Interop/ICustomMarshaler/Primitives/ICustomMarshaler.cs b/src/tests/Interop/ICustomMarshaler/Primitives/ICustomMarshaler.cs index 271c7485d19068..78d7d541872a4a 100644 --- a/src/tests/Interop/ICustomMarshaler/Primitives/ICustomMarshaler.cs +++ b/src/tests/Interop/ICustomMarshaler/Primitives/ICustomMarshaler.cs @@ -365,7 +365,7 @@ public void Parameter_NotICustomMarshaler_ThrowsApplicationException() { Assert.Throws(() => NonICustomMarshalerMethod("")); } - + [DllImport(LibcLibrary, EntryPoint = "atoi", CallingConvention = CallingConvention.Cdecl)] public static extern int NonICustomMarshalerMethod([MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(string))] string str); @@ -509,7 +509,7 @@ public void CleanUpNativeData(IntPtr pNativeData) { } [Fact] public void Parameter_GetInstanceMethodThrows_ThrowsActualException() - { + { Assert.Throws(() => ThrowingGetInstanceMethod("")); } @@ -588,6 +588,58 @@ public struct StructWithCustomMarshalerField [DllImport(LibcLibrary, EntryPoint = "atoi", CallingConvention = CallingConvention.Cdecl)] public static extern int StructWithCustomMarshalerFieldMethod(StructWithCustomMarshalerField c); + + [Fact] + public void Parameter_DifferentCustomMarshalerType_MarshalsCorrectly() + { + Assert.Equal(234, DifferentCustomMarshalerType("5678")); + } + + public class OuterCustomMarshaler : ICustomMarshaler + { + public void CleanUpManagedData(object ManagedObj) => throw new NotImplementedException(); + public void CleanUpNativeData(IntPtr pNativeData) => throw new NotImplementedException(); + + public int GetNativeDataSize() => throw new NotImplementedException(); + + public IntPtr MarshalManagedToNative(object ManagedObj) => throw new NotImplementedException(); + public object MarshalNativeToManaged(IntPtr pNativeData) => throw new NotImplementedException(); + + public static ICustomMarshaler GetInstance(string cookie) => new InnerCustomMarshaler(); + + private interface ILargeInterface + { + void Method1(); + void Method2(); + void Method3(); + void Method4(); + void Method5(); + void Method6(); + } + + private class InnerCustomMarshaler : ILargeInterface, ICustomMarshaler + { + public void Method1() => throw new InvalidOperationException(); + public void Method2() => throw new InvalidOperationException(); + public void Method3() => throw new InvalidOperationException(); + public void Method4() => throw new InvalidOperationException(); + public void Method5() => throw new InvalidOperationException(); + public void Method6() => throw new InvalidOperationException(); + + public void CleanUpManagedData(object ManagedObj) { } + public void CleanUpNativeData(IntPtr pNativeData) => Marshal.FreeCoTaskMem(pNativeData); + + public int GetNativeDataSize() => IntPtr.Size; + + public IntPtr MarshalManagedToNative(object ManagedObj) => Marshal.StringToCoTaskMemAnsi("234"); + public object MarshalNativeToManaged(IntPtr pNativeData) => null; + } + } + + [DllImport(LibcLibrary, EntryPoint = "atoi", CallingConvention = CallingConvention.Cdecl)] + public static extern int DifferentCustomMarshalerType([MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(OuterCustomMarshaler))] string str); + + public static int Main(String[] args) { return new ICustomMarshalerTests().RunTests();