diff --git a/src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs index 97065b11c2b0ae..d0b63ccc546ec1 100644 --- a/src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs +++ b/src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs @@ -542,12 +542,39 @@ public BasicClassFactory(Guid clsid, [DynamicallyAccessedMembers(DynamicallyAcce _classType = classType; } - public static Type GetValidatedInterfaceType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] Type classType, ref Guid riid, object? outer) + public enum ValidatedInterfaceKind + { + IUnknown, + IDispatch, + ManagedType, + } + + public struct ValidatedInterfaceType + { + public ValidatedInterfaceKind Kind { get; init; } + public Type? ManagedType { get; init; } + } + + public static ValidatedInterfaceType CreateValidatedInterfaceType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] Type classType, ref Guid riid, object? outer) { Debug.Assert(classType != null); if (riid == Marshal.IID_IUnknown) { - return typeof(object); + return new ValidatedInterfaceType() { Kind = ValidatedInterfaceKind.IUnknown, ManagedType = null }; + } + else if (riid == Marshal.IID_IDispatch) + { + ClassInterfaceAttribute? attr = + classType.GetCustomAttribute() + ?? classType.Assembly.GetCustomAttribute(); // If there is no attribute on the Type, check the Assembly. + + // If the attribute is unspecified, the default is ClassInterfaceType.AutoDispatch. + // See DEFAULT_CLASS_INTERFACE_TYPE in native. + if (attr is null + || attr.Value is ClassInterfaceType.AutoDispatch or ClassInterfaceType.AutoDual) + { + return new ValidatedInterfaceType() { Kind = ValidatedInterfaceKind.IDispatch, ManagedType = null }; + } } // Aggregation can only be done when requesting IUnknown. @@ -562,7 +589,7 @@ public static Type GetValidatedInterfaceType([DynamicallyAccessedMembers(Dynamic { if (i.GUID == riid) { - return i; + return new ValidatedInterfaceType() { Kind = ValidatedInterfaceKind.ManagedType, ManagedType = i }; } } @@ -570,15 +597,22 @@ public static Type GetValidatedInterfaceType([DynamicallyAccessedMembers(Dynamic throw new InvalidCastException(); } - public static IntPtr GetObjectAsInterface(object obj, Type interfaceType) + public static IntPtr GetObjectAsInterface(object obj, ValidatedInterfaceType interfaceType) { - // If the requested "interface type" is type object then return as IUnknown - if (interfaceType == typeof(object)) + if (interfaceType.Kind is ValidatedInterfaceKind.IUnknown) { + Debug.Assert(interfaceType.ManagedType is null); return Marshal.GetIUnknownForObject(obj); } + else if (interfaceType.Kind is ValidatedInterfaceKind.IDispatch) + { + Debug.Assert(interfaceType.ManagedType is null); + return Marshal.GetIDispatchForObject(obj); + } - Debug.Assert(interfaceType.IsInterface); + Debug.Assert(interfaceType.Kind is ValidatedInterfaceKind.ManagedType + && interfaceType.ManagedType != null + && interfaceType.ManagedType.IsInterface); // The intent of this call is to get AND validate the interface can be // marshalled to native code. An exception will be thrown if the @@ -586,7 +620,7 @@ public static IntPtr GetObjectAsInterface(object obj, Type interfaceType) // Scenarios where this is relevant: // - Interfaces that use Generics // - Interfaces that define implementation - IntPtr interfaceMaybe = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore); + IntPtr interfaceMaybe = Marshal.GetComInterfaceForObject(obj, interfaceType.ManagedType, CustomQueryInterfaceMode.Ignore); if (interfaceMaybe == IntPtr.Zero) { @@ -620,7 +654,7 @@ public void CreateInstance( ref Guid riid, out IntPtr ppvObject) { - Type interfaceType = GetValidatedInterfaceType(_classType, ref riid, pUnkOuter); + var interfaceType = CreateValidatedInterfaceType(_classType, ref riid, pUnkOuter); object obj = Activator.CreateInstance(_classType)!; if (pUnkOuter != null) @@ -700,7 +734,7 @@ private void CreateInstanceInner( bool isDesignTime, out IntPtr ppvObject) { - Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter); + var interfaceType = BasicClassFactory.CreateValidatedInterfaceType(_classType, ref riid, pUnkOuter); object obj = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime); if (pUnkOuter != null) diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs index 82a9e5d9ec258b..d7894da107dd14 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs @@ -22,6 +22,11 @@ public static partial class Marshal /// IUnknown is {00000000-0000-0000-C000-000000000046} /// internal static readonly Guid IID_IUnknown = new Guid(0, 0, 0, 0xC0, 0, 0, 0, 0, 0, 0, 0x46); + + /// + /// IDispatch is {00020400-0000-0000-C000-000000000046} + /// + internal static readonly Guid IID_IDispatch = new Guid(0x00020400, 0, 0, 0xC0, 0, 0, 0, 0, 0, 0, 0x46); #endif //FEATURE_COMINTEROP internal static int SizeOfHelper(RuntimeType t, [MarshalAs(UnmanagedType.Bool)] bool throwIfNotMarshalable) diff --git a/src/tests/Interop/COM/NETServer/ClassInterfaceTesting.cs b/src/tests/Interop/COM/NETServer/ClassInterfaceTesting.cs new file mode 100644 index 00000000000000..d35b84fb11a6c3 --- /dev/null +++ b/src/tests/Interop/COM/NETServer/ClassInterfaceTesting.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; + +#pragma warning disable 618 // Must test deprecated features + +[ComVisible(true)] +[Guid(Server.Contract.Guids.ClassInterfaceNotSetTesting)] +public class ClassInterfaceNotSetTesting +{ +} + +[ComVisible(true)] +[Guid(Server.Contract.Guids.ClassInterfaceNoneTesting)] +[ClassInterface(ClassInterfaceType.None)] +public class ClassInterfaceNoneTesting +{ +} + +[ComVisible(true)] +[Guid(Server.Contract.Guids.ClassInterfaceAutoDispatchTesting)] +[ClassInterface(ClassInterfaceType.AutoDispatch)] +public class ClassInterfaceAutoDispatchTesting +{ +} + +[ComVisible(true)] +[Guid(Server.Contract.Guids.ClassInterfaceAutoDualTesting)] +[ClassInterface(ClassInterfaceType.AutoDual)] +public class ClassInterfaceAutoDualTesting +{ +} + +#pragma warning restore 618 // Must test deprecated features \ No newline at end of file diff --git a/src/tests/Interop/COM/NativeClients/MiscTypes/CoreShim.X.manifest b/src/tests/Interop/COM/NativeClients/MiscTypes/CoreShim.X.manifest index a3c8593ee06761..457ae854700392 100644 --- a/src/tests/Interop/COM/NativeClients/MiscTypes/CoreShim.X.manifest +++ b/src/tests/Interop/COM/NativeClients/MiscTypes/CoreShim.X.manifest @@ -11,6 +11,22 @@ + + + + + + + + diff --git a/src/tests/Interop/COM/NativeClients/MiscTypes/MiscTypes.cpp b/src/tests/Interop/COM/NativeClients/MiscTypes/MiscTypes.cpp index ad0bb044c8e4e0..7c0adb003b807d 100644 --- a/src/tests/Interop/COM/NativeClients/MiscTypes/MiscTypes.cpp +++ b/src/tests/Interop/COM/NativeClients/MiscTypes/MiscTypes.cpp @@ -35,6 +35,7 @@ struct ComInit using ComMTA = ComInit; void ValidationTests(); void ValidationByRefTests(); +void ValidationClassInterfaceTests(); int __cdecl main() { @@ -59,6 +60,16 @@ int __cdecl main() return 101; } + try + { + ValidationClassInterfaceTests(); + } + catch (HRESULT hr) + { + ::printf("Test Failure: 0x%08x\n", hr); + return 101; + } + return 100; } @@ -84,7 +95,7 @@ void ValidationTests() HRESULT hr; - IMiscTypesTesting *miscTypesTesting; + ComSmartPtr miscTypesTesting; THROW_IF_FAILED(::CoCreateInstance(CLSID_MiscTypesTesting, nullptr, CLSCTX_INPROC, IID_IMiscTypesTesting, (void**)&miscTypesTesting)); ::printf("-- Primitives <=> VARIANT...\n"); @@ -328,7 +339,7 @@ void ValidationByRefTests() HRESULT hr; - IMiscTypesTesting *miscTypesTesting; + ComSmartPtr miscTypesTesting; THROW_IF_FAILED(::CoCreateInstance(CLSID_MiscTypesTesting, nullptr, CLSCTX_INPROC, IID_IMiscTypesTesting, (void**)&miscTypesTesting)); ::printf("-- Primitives <=> BYREF VARIANT...\n"); @@ -365,7 +376,7 @@ void ValidationByRefTests() THROW_FAIL_IF_FALSE(CompareStringOrdinal(expected, -1, value, -1, FALSE) == CSTR_EQUAL); ::SysFreeString(expected); } - + ::printf("-- System.Guid <=> BYREF VARIANT...\n"); { /* 8EFAD956-B33D-46CB-90F4-45F55BA68A96 */ @@ -378,7 +389,7 @@ void ValidationByRefTests() THROW_FAIL_IF_FALSE(memcmp(V_RECORD(&guidVar.Input), &expected, sizeof(expected)) == 0); THROW_IF_FAILED(miscTypesTesting->Marshal_Instance_Variant(W("{00000000-0000-0000-0000-000000000000}"), &guidVar.Result)); THROW_FAIL_IF_FALSE(V_VT(&guidVar.Result) == VT_RECORD); - + // Use the Guid as input. VariantMarshalTest args{}; THROW_IF_FAILED(::VariantCopy(&args.Input, &guidVar.Input)); @@ -399,3 +410,47 @@ void ValidationByRefTests() THROW_FAIL_IF_FALSE(miscTypesTesting->Marshal_ByRefVariant(&args.Result, args.Input) == 0x80004002); // COR_E_INVALIDCAST } } + +void ValidationClassInterfaceTests() +{ + ::printf(__FUNCTION__ "() through CoCreateInstance...\n"); + + HRESULT hr; + + { + CoreShimComActivation csact{ W("NETServer"), W("ClassInterfaceNotSetTesting") }; + ::printf("-- ClassInterfaceType not set ...\n"); + + ComSmartPtr pUnk; + THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceNotSetTesting, nullptr, CLSCTX_INPROC, IID_IUnknown, (void**)&pUnk)); + ComSmartPtr pDisp; + THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceNotSetTesting, nullptr, CLSCTX_INPROC, IID_IDispatch, (void**)&pDisp)); + } + { + CoreShimComActivation csact{ W("NETServer"), W("ClassInterfaceNoneTesting") }; + ::printf("-- ClassInterfaceType.None ...\n"); + + ComSmartPtr pUnk; + THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceNoneTesting, nullptr, CLSCTX_INPROC, IID_IUnknown, (void**)&pUnk)); + ComSmartPtr pDisp; + THROW_FAIL_IF_FALSE(E_NOINTERFACE == ::CoCreateInstance(CLSID_ClassInterfaceNoneTesting, nullptr, CLSCTX_INPROC, IID_IDispatch, (void**)&pDisp)); + } + { + CoreShimComActivation csact{ W("NETServer"), W("ClassInterfaceAutoDispatchTesting") }; + ::printf("-- ClassInterfaceType.AutoDispatch ...\n"); + + ComSmartPtr pUnk; + THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceAutoDispatchTesting, nullptr, CLSCTX_INPROC, IID_IUnknown, (void**)&pUnk)); + ComSmartPtr pDisp; + THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceAutoDispatchTesting, nullptr, CLSCTX_INPROC, IID_IDispatch, (void**)&pDisp)); + } + { + CoreShimComActivation csact{ W("NETServer"), W("ClassInterfaceAutoDualTesting") }; + ::printf("-- ClassInterfaceType.AutoDual ...\n"); + + ComSmartPtr pUnk; + THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceAutoDualTesting, nullptr, CLSCTX_INPROC, IID_IUnknown, (void**)&pUnk)); + ComSmartPtr pDisp; + THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceAutoDualTesting, nullptr, CLSCTX_INPROC, IID_IDispatch, (void**)&pDisp)); + } +} \ No newline at end of file diff --git a/src/tests/Interop/COM/NativeServer/Servers.h b/src/tests/Interop/COM/NativeServer/Servers.h index 44e5070a25a6b2..f0d71b45c0e969 100644 --- a/src/tests/Interop/COM/NativeServer/Servers.h +++ b/src/tests/Interop/COM/NativeServer/Servers.h @@ -23,6 +23,10 @@ class DECLSPEC_UUID("66DB7882-E2B0-471D-92C7-B2B52A0EA535") LicenseTesting; class DECLSPEC_UUID("FAEF42AE-C1A4-419F-A912-B768AC2679EA") DefaultInterfaceTesting; class DECLSPEC_UUID("CE137261-6F19-44F5-A449-EF963B3F987E") InspectableTesting; class DECLSPEC_UUID("4F54231D-9E11-4C0B-8E0B-2EBD8B0E5811") TrackMyLifetimeTesting; +class DECLSPEC_UUID("B8314D5A-DE70-435B-AD97-8F88820D1F3C") ClassInterfaceNotSetTesting; +class DECLSPEC_UUID("ED4D9C70-1C9F-406B-B51F-87DD977AF3B2") ClassInterfaceNoneTesting; +class DECLSPEC_UUID("C1A0AE72-791B-4380-946E-B7BABDEA1701") ClassInterfaceAutoDispatchTesting; +class DECLSPEC_UUID("95696E2C-742F-4639-A9D4-5D36EE021C49") ClassInterfaceAutoDualTesting; #define CLSID_NumericTesting __uuidof(NumericTesting) #define CLSID_ArrayTesting __uuidof(ArrayTesting) @@ -38,6 +42,10 @@ class DECLSPEC_UUID("4F54231D-9E11-4C0B-8E0B-2EBD8B0E5811") TrackMyLifetimeTesti #define CLSID_DefaultInterfaceTesting __uuidof(DefaultInterfaceTesting) #define CLSID_InspectableTesting __uuidof(InspectableTesting) #define CLSID_TrackMyLifetimeTesting __uuidof(TrackMyLifetimeTesting) +#define CLSID_ClassInterfaceNotSetTesting __uuidof(ClassInterfaceNotSetTesting) +#define CLSID_ClassInterfaceNoneTesting __uuidof(ClassInterfaceNoneTesting) +#define CLSID_ClassInterfaceAutoDispatchTesting __uuidof(ClassInterfaceAutoDispatchTesting) +#define CLSID_ClassInterfaceAutoDualTesting __uuidof(ClassInterfaceAutoDualTesting) #define IID_INumericTesting __uuidof(INumericTesting) #define IID_IArrayTesting __uuidof(IArrayTesting) diff --git a/src/tests/Interop/COM/ServerContracts/ServerGuids.cs b/src/tests/Interop/COM/ServerContracts/ServerGuids.cs index 6c6b6569b634a7..7611e588dcbd0f 100644 --- a/src/tests/Interop/COM/ServerContracts/ServerGuids.cs +++ b/src/tests/Interop/COM/ServerContracts/ServerGuids.cs @@ -23,5 +23,9 @@ internal sealed class Guids public const string ConsumeNETServerTesting = "DE4ACF53-5957-4D31-8BE2-EA6C80683246"; public const string InspectableTesting = "CE137261-6F19-44F5-A449-EF963B3F987E"; public const string TrackMyLifetimeTesting = "4F54231D-9E11-4C0B-8E0B-2EBD8B0E5811"; + public const string ClassInterfaceNotSetTesting = "B8314D5A-DE70-435B-AD97-8F88820D1F3C"; + public const string ClassInterfaceNoneTesting = "ED4D9C70-1C9F-406B-B51F-87DD977AF3B2"; + public const string ClassInterfaceAutoDispatchTesting = "C1A0AE72-791B-4380-946E-B7BABDEA1701"; + public const string ClassInterfaceAutoDualTesting = "95696E2C-742F-4639-A9D4-5D36EE021C49"; } }