Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -542,12 +542,39 @@ public BasicClassFactory(Guid clsid, [DynamicallyAccessedMembers(DynamicallyAcce
_classType = classType;
}

public static Type GetValidatedInterfaceType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] Type classType, ref Guid riid, object? outer)
public enum ValidatedInterfaceKind
{
IUnknown,
IDispatch,
ManagedType,
}

public struct ValidatedInterfaceType
{
public ValidatedInterfaceKind Kind { get; init; }
public Type? ManagedType { get; init; }
}

public static ValidatedInterfaceType CreateValidatedInterfaceType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] Type classType, ref Guid riid, object? outer)
{
Debug.Assert(classType != null);
if (riid == Marshal.IID_IUnknown)
{
return typeof(object);
return new ValidatedInterfaceType() { Kind = ValidatedInterfaceKind.IUnknown, ManagedType = null };
}
else if (riid == Marshal.IID_IDispatch)
{
ClassInterfaceAttribute? attr =
classType.GetCustomAttribute<ClassInterfaceAttribute>()
?? classType.Assembly.GetCustomAttribute<ClassInterfaceAttribute>(); // If there is no attribute on the Type, check the Assembly.

// If the attribute is unspecified, the default is ClassInterfaceType.AutoDispatch.
// See DEFAULT_CLASS_INTERFACE_TYPE in native.
if (attr is null
|| attr.Value is ClassInterfaceType.AutoDispatch or ClassInterfaceType.AutoDual)
{
return new ValidatedInterfaceType() { Kind = ValidatedInterfaceKind.IDispatch, ManagedType = null };
}
}

// Aggregation can only be done when requesting IUnknown.
Expand All @@ -562,31 +589,38 @@ public static Type GetValidatedInterfaceType([DynamicallyAccessedMembers(Dynamic
{
if (i.GUID == riid)
{
return i;
return new ValidatedInterfaceType() { Kind = ValidatedInterfaceKind.ManagedType, ManagedType = i };
}
}

// E_NOINTERFACE
throw new InvalidCastException();
}

public static IntPtr GetObjectAsInterface(object obj, Type interfaceType)
public static IntPtr GetObjectAsInterface(object obj, ValidatedInterfaceType interfaceType)
{
// If the requested "interface type" is type object then return as IUnknown
if (interfaceType == typeof(object))
if (interfaceType.Kind is ValidatedInterfaceKind.IUnknown)
{
Debug.Assert(interfaceType.ManagedType is null);
return Marshal.GetIUnknownForObject(obj);
}
else if (interfaceType.Kind is ValidatedInterfaceKind.IDispatch)
{
Debug.Assert(interfaceType.ManagedType is null);
return Marshal.GetIDispatchForObject(obj);
}

Debug.Assert(interfaceType.IsInterface);
Debug.Assert(interfaceType.Kind is ValidatedInterfaceKind.ManagedType
&& interfaceType.ManagedType != null
&& interfaceType.ManagedType.IsInterface);

// The intent of this call is to get AND validate the interface can be
// marshalled to native code. An exception will be thrown if the
// type is unable to be marshalled to native code.
// Scenarios where this is relevant:
// - Interfaces that use Generics
// - Interfaces that define implementation
IntPtr interfaceMaybe = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore);
IntPtr interfaceMaybe = Marshal.GetComInterfaceForObject(obj, interfaceType.ManagedType, CustomQueryInterfaceMode.Ignore);

if (interfaceMaybe == IntPtr.Zero)
{
Expand Down Expand Up @@ -620,7 +654,7 @@ public void CreateInstance(
ref Guid riid,
out IntPtr ppvObject)
{
Type interfaceType = GetValidatedInterfaceType(_classType, ref riid, pUnkOuter);
var interfaceType = CreateValidatedInterfaceType(_classType, ref riid, pUnkOuter);

object obj = Activator.CreateInstance(_classType)!;
if (pUnkOuter != null)
Expand Down Expand Up @@ -700,7 +734,7 @@ private void CreateInstanceInner(
bool isDesignTime,
out IntPtr ppvObject)
{
Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter);
var interfaceType = BasicClassFactory.CreateValidatedInterfaceType(_classType, ref riid, pUnkOuter);
Comment thread
AaronRobinsonMSFT marked this conversation as resolved.

object obj = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime);
if (pUnkOuter != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ public static partial class Marshal
/// IUnknown is {00000000-0000-0000-C000-000000000046}
/// </summary>
internal static readonly Guid IID_IUnknown = new Guid(0, 0, 0, 0xC0, 0, 0, 0, 0, 0, 0, 0x46);

/// <summary>
/// IDispatch is {00020400-0000-0000-C000-000000000046}
/// </summary>
internal static readonly Guid IID_IDispatch = new Guid(0x00020400, 0, 0, 0xC0, 0, 0, 0, 0, 0, 0, 0x46);
#endif //FEATURE_COMINTEROP

internal static int SizeOfHelper(RuntimeType t, [MarshalAs(UnmanagedType.Bool)] bool throwIfNotMarshalable)
Expand Down
36 changes: 36 additions & 0 deletions src/tests/Interop/COM/NETServer/ClassInterfaceTesting.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;

#pragma warning disable 618 // Must test deprecated features

[ComVisible(true)]
[Guid(Server.Contract.Guids.ClassInterfaceNotSetTesting)]
public class ClassInterfaceNotSetTesting
{
}

[ComVisible(true)]
[Guid(Server.Contract.Guids.ClassInterfaceNoneTesting)]
[ClassInterface(ClassInterfaceType.None)]
public class ClassInterfaceNoneTesting
{
}

[ComVisible(true)]
[Guid(Server.Contract.Guids.ClassInterfaceAutoDispatchTesting)]
[ClassInterface(ClassInterfaceType.AutoDispatch)]
public class ClassInterfaceAutoDispatchTesting
{
}

[ComVisible(true)]
[Guid(Server.Contract.Guids.ClassInterfaceAutoDualTesting)]
[ClassInterface(ClassInterfaceType.AutoDual)]
public class ClassInterfaceAutoDualTesting
{
}

#pragma warning restore 618 // Must test deprecated features
16 changes: 16 additions & 0 deletions src/tests/Interop/COM/NativeClients/MiscTypes/CoreShim.X.manifest
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@
<comClass
clsid="{CCFF894B-A27C-45E0-9B30-6C88D722E843}"
threadingModel="Both" />
<!-- ClassInterfaceNotSetTesting -->
<comClass
clsid="{B8314D5A-DE70-435B-AD97-8F88820D1F3C}"
threadingModel="Both" />
<!-- ClassInterfaceNoneTesting -->
<comClass
clsid="{ED4D9C70-1C9F-406B-B51F-87DD977AF3B2}"
threadingModel="Both" />
<!-- ClassInterfaceAutoDispatchTesting -->
<comClass
clsid="{C1A0AE72-791B-4380-946E-B7BABDEA1701}"
threadingModel="Both" />
<!-- ClassInterfaceAutoDualTesting -->
<comClass
clsid="{95696E2C-742F-4639-A9D4-5D36EE021C49}"
threadingModel="Both" />
</file>

</assembly>
63 changes: 59 additions & 4 deletions src/tests/Interop/COM/NativeClients/MiscTypes/MiscTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct ComInit
using ComMTA = ComInit<COINIT_MULTITHREADED>;
void ValidationTests();
void ValidationByRefTests();
void ValidationClassInterfaceTests();

int __cdecl main()
{
Expand All @@ -59,6 +60,16 @@ int __cdecl main()
return 101;
}

try
{
ValidationClassInterfaceTests();
}
catch (HRESULT hr)
{
::printf("Test Failure: 0x%08x\n", hr);
return 101;
}

return 100;
}

Expand All @@ -84,7 +95,7 @@ void ValidationTests()

HRESULT hr;

IMiscTypesTesting *miscTypesTesting;
ComSmartPtr<IMiscTypesTesting> miscTypesTesting;
THROW_IF_FAILED(::CoCreateInstance(CLSID_MiscTypesTesting, nullptr, CLSCTX_INPROC, IID_IMiscTypesTesting, (void**)&miscTypesTesting));

::printf("-- Primitives <=> VARIANT...\n");
Expand Down Expand Up @@ -328,7 +339,7 @@ void ValidationByRefTests()

HRESULT hr;

IMiscTypesTesting *miscTypesTesting;
ComSmartPtr<IMiscTypesTesting> miscTypesTesting;
THROW_IF_FAILED(::CoCreateInstance(CLSID_MiscTypesTesting, nullptr, CLSCTX_INPROC, IID_IMiscTypesTesting, (void**)&miscTypesTesting));

::printf("-- Primitives <=> BYREF VARIANT...\n");
Expand Down Expand Up @@ -365,7 +376,7 @@ void ValidationByRefTests()
THROW_FAIL_IF_FALSE(CompareStringOrdinal(expected, -1, value, -1, FALSE) == CSTR_EQUAL);
::SysFreeString(expected);
}

::printf("-- System.Guid <=> BYREF VARIANT...\n");
{
/* 8EFAD956-B33D-46CB-90F4-45F55BA68A96 */
Expand All @@ -378,7 +389,7 @@ void ValidationByRefTests()
THROW_FAIL_IF_FALSE(memcmp(V_RECORD(&guidVar.Input), &expected, sizeof(expected)) == 0);
THROW_IF_FAILED(miscTypesTesting->Marshal_Instance_Variant(W("{00000000-0000-0000-0000-000000000000}"), &guidVar.Result));
THROW_FAIL_IF_FALSE(V_VT(&guidVar.Result) == VT_RECORD);

// Use the Guid as input.
VariantMarshalTest args{};
THROW_IF_FAILED(::VariantCopy(&args.Input, &guidVar.Input));
Expand All @@ -399,3 +410,47 @@ void ValidationByRefTests()
THROW_FAIL_IF_FALSE(miscTypesTesting->Marshal_ByRefVariant(&args.Result, args.Input) == 0x80004002); // COR_E_INVALIDCAST
}
}

void ValidationClassInterfaceTests()
{
::printf(__FUNCTION__ "() through CoCreateInstance...\n");

HRESULT hr;

{
CoreShimComActivation csact{ W("NETServer"), W("ClassInterfaceNotSetTesting") };
::printf("-- ClassInterfaceType not set ...\n");

ComSmartPtr<IUnknown> pUnk;
THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceNotSetTesting, nullptr, CLSCTX_INPROC, IID_IUnknown, (void**)&pUnk));
ComSmartPtr<IDispatch> pDisp;
THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceNotSetTesting, nullptr, CLSCTX_INPROC, IID_IDispatch, (void**)&pDisp));
}
{
CoreShimComActivation csact{ W("NETServer"), W("ClassInterfaceNoneTesting") };
::printf("-- ClassInterfaceType.None ...\n");

ComSmartPtr<IUnknown> pUnk;
THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceNoneTesting, nullptr, CLSCTX_INPROC, IID_IUnknown, (void**)&pUnk));
ComSmartPtr<IDispatch> pDisp;
THROW_FAIL_IF_FALSE(E_NOINTERFACE == ::CoCreateInstance(CLSID_ClassInterfaceNoneTesting, nullptr, CLSCTX_INPROC, IID_IDispatch, (void**)&pDisp));
}
{
CoreShimComActivation csact{ W("NETServer"), W("ClassInterfaceAutoDispatchTesting") };
::printf("-- ClassInterfaceType.AutoDispatch ...\n");

ComSmartPtr<IUnknown> pUnk;
THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceAutoDispatchTesting, nullptr, CLSCTX_INPROC, IID_IUnknown, (void**)&pUnk));
ComSmartPtr<IDispatch> pDisp;
THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceAutoDispatchTesting, nullptr, CLSCTX_INPROC, IID_IDispatch, (void**)&pDisp));
}
{
CoreShimComActivation csact{ W("NETServer"), W("ClassInterfaceAutoDualTesting") };
::printf("-- ClassInterfaceType.AutoDual ...\n");

ComSmartPtr<IUnknown> pUnk;
THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceAutoDualTesting, nullptr, CLSCTX_INPROC, IID_IUnknown, (void**)&pUnk));
ComSmartPtr<IDispatch> pDisp;
THROW_IF_FAILED(::CoCreateInstance(CLSID_ClassInterfaceAutoDualTesting, nullptr, CLSCTX_INPROC, IID_IDispatch, (void**)&pDisp));
}
}
8 changes: 8 additions & 0 deletions src/tests/Interop/COM/NativeServer/Servers.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class DECLSPEC_UUID("66DB7882-E2B0-471D-92C7-B2B52A0EA535") LicenseTesting;
class DECLSPEC_UUID("FAEF42AE-C1A4-419F-A912-B768AC2679EA") DefaultInterfaceTesting;
class DECLSPEC_UUID("CE137261-6F19-44F5-A449-EF963B3F987E") InspectableTesting;
class DECLSPEC_UUID("4F54231D-9E11-4C0B-8E0B-2EBD8B0E5811") TrackMyLifetimeTesting;
class DECLSPEC_UUID("B8314D5A-DE70-435B-AD97-8F88820D1F3C") ClassInterfaceNotSetTesting;
class DECLSPEC_UUID("ED4D9C70-1C9F-406B-B51F-87DD977AF3B2") ClassInterfaceNoneTesting;
class DECLSPEC_UUID("C1A0AE72-791B-4380-946E-B7BABDEA1701") ClassInterfaceAutoDispatchTesting;
class DECLSPEC_UUID("95696E2C-742F-4639-A9D4-5D36EE021C49") ClassInterfaceAutoDualTesting;

#define CLSID_NumericTesting __uuidof(NumericTesting)
#define CLSID_ArrayTesting __uuidof(ArrayTesting)
Expand All @@ -38,6 +42,10 @@ class DECLSPEC_UUID("4F54231D-9E11-4C0B-8E0B-2EBD8B0E5811") TrackMyLifetimeTesti
#define CLSID_DefaultInterfaceTesting __uuidof(DefaultInterfaceTesting)
#define CLSID_InspectableTesting __uuidof(InspectableTesting)
#define CLSID_TrackMyLifetimeTesting __uuidof(TrackMyLifetimeTesting)
#define CLSID_ClassInterfaceNotSetTesting __uuidof(ClassInterfaceNotSetTesting)
#define CLSID_ClassInterfaceNoneTesting __uuidof(ClassInterfaceNoneTesting)
#define CLSID_ClassInterfaceAutoDispatchTesting __uuidof(ClassInterfaceAutoDispatchTesting)
#define CLSID_ClassInterfaceAutoDualTesting __uuidof(ClassInterfaceAutoDualTesting)

#define IID_INumericTesting __uuidof(INumericTesting)
#define IID_IArrayTesting __uuidof(IArrayTesting)
Expand Down
4 changes: 4 additions & 0 deletions src/tests/Interop/COM/ServerContracts/ServerGuids.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,9 @@ internal sealed class Guids
public const string ConsumeNETServerTesting = "DE4ACF53-5957-4D31-8BE2-EA6C80683246";
public const string InspectableTesting = "CE137261-6F19-44F5-A449-EF963B3F987E";
public const string TrackMyLifetimeTesting = "4F54231D-9E11-4C0B-8E0B-2EBD8B0E5811";
public const string ClassInterfaceNotSetTesting = "B8314D5A-DE70-435B-AD97-8F88820D1F3C";
public const string ClassInterfaceNoneTesting = "ED4D9C70-1C9F-406B-B51F-87DD977AF3B2";
public const string ClassInterfaceAutoDispatchTesting = "C1A0AE72-791B-4380-946E-B7BABDEA1701";
public const string ClassInterfaceAutoDualTesting = "95696E2C-742F-4639-A9D4-5D36EE021C49";
}
}