diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/DynamicInterfaceCastableHelpers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/DynamicInterfaceCastableHelpers.cs index 6ea1f1b444a507..da3cd7f1059e27 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/DynamicInterfaceCastableHelpers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/DynamicInterfaceCastableHelpers.cs @@ -34,7 +34,7 @@ internal static bool IsInterfaceImplemented(IDynamicInterfaceCastable castable, if (!implType.IsDefined(typeof(DynamicInterfaceCastableImplementationAttribute), inherit: false)) throw new InvalidOperationException(SR.Format(SR.IDynamicInterfaceCastable_MissingImplementationAttribute, implType, nameof(DynamicInterfaceCastableImplementationAttribute))); - if (!implType.ImplementInterface(interfaceType)) + if (!implType.IsAssignableTo(interfaceType)) throw new InvalidOperationException(SR.Format(SR.IDynamicInterfaceCastable_DoesNotImplementRequested, implType, interfaceType)); return implType; diff --git a/src/tests/Interop/IDynamicInterfaceCastable/Program.cs b/src/tests/Interop/IDynamicInterfaceCastable/Program.cs index 2cf471f0464316..c334a31e38ac98 100644 --- a/src/tests/Interop/IDynamicInterfaceCastable/Program.cs +++ b/src/tests/Interop/IDynamicInterfaceCastable/Program.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using TestLibrary; @@ -25,9 +26,9 @@ public interface ITest int CallImplemented(ImplementationToCall toCall); } - public interface ITestGeneric + public interface ITestGeneric { - T ReturnArg(T t); + U ReturnArg(T t); } public interface IDirectlyImplemented @@ -106,18 +107,24 @@ Type ITest.GetMyType() } [DynamicInterfaceCastableImplementation] - public interface ITestGenericImpl: ITestGeneric + public interface ITestGenericImpl: ITestGeneric { - T ITestGeneric.ReturnArg(T t) + U ITestGeneric.ReturnArg(T t) { - return t; + if (!typeof(T).IsAssignableTo(typeof(U)) + && !t.GetType().IsAssignableTo(typeof(U))) + { + throw new Exception($"Invalid covariance conversion from {typeof(T)} or {t.GetType()} to {typeof(U)}"); + } + + return Unsafe.As(ref t); } } [DynamicInterfaceCastableImplementation] - public interface ITestGenericIntImpl: ITestGeneric + public interface ITestGenericIntImpl: ITestGeneric { - int ITestGeneric.ReturnArg(int i) + int ITestGeneric.ReturnArg(int i) { return i; } @@ -302,27 +309,34 @@ private static void ValidateGenericInterface() Console.WriteLine($"Running {nameof(ValidateGenericInterface)}"); object castableObj = new DynamicInterfaceCastable(new Dictionary { - { typeof(ITestGeneric), typeof(ITestGenericIntImpl) }, - { typeof(ITestGeneric), typeof(ITestGenericImpl) }, + { typeof(ITestGeneric), typeof(ITestGenericIntImpl) }, + { typeof(ITestGeneric), typeof(ITestGenericImpl) }, + { typeof(ITestGeneric), typeof(ITestGenericImpl) }, }); Console.WriteLine(" -- Validate cast"); - // ITestGeneric -> ITestGenericIntImpl - Assert.IsTrue(castableObj is ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via is"); - Assert.IsNotNull(castableObj as ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via as"); - ITestGeneric testInt = (ITestGeneric)castableObj; + // ITestGeneric -> ITestGenericIntImpl + Assert.IsTrue(castableObj is ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via is"); + Assert.IsNotNull(castableObj as ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via as"); + ITestGeneric testInt = (ITestGeneric)castableObj; + + // ITestGeneric -> ITestGenericImpl + Assert.IsTrue(castableObj is ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via is"); + Assert.IsNotNull(castableObj as ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via as"); + ITestGeneric testStr = (ITestGeneric)castableObj; - // ITestGeneric -> ITestGenericImpl - Assert.IsTrue(castableObj is ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via is"); - Assert.IsNotNull(castableObj as ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via as"); - ITestGeneric testStr = (ITestGeneric)castableObj; + // Validate Variance + // ITestGeneric -> ITestGenericImpl + Assert.IsTrue(castableObj is ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via is"); + Assert.IsNotNull(castableObj as ITestGeneric, $"Should be castable to {nameof(ITestGeneric)} via as"); + ITestGeneric testVar = (ITestGeneric)castableObj; - // ITestGeneric is not recognized - Assert.IsFalse(castableObj is ITestGeneric, $"Should not be castable to {nameof(ITestGeneric)} via is"); - Assert.IsNull(castableObj as ITestGeneric, $"Should not be castable to {nameof(ITestGeneric)} via as"); - var ex = Assert.Throws(() => { var _ = (ITestGeneric)castableObj; }); - Assert.AreEqual(string.Format(DynamicInterfaceCastableException.ErrorFormat, typeof(ITestGeneric)), ex.Message); + // ITestGeneric is not recognized + Assert.IsFalse(castableObj is ITestGeneric, $"Should not be castable to {nameof(ITestGeneric)} via is"); + Assert.IsNull(castableObj as ITestGeneric, $"Should not be castable to {nameof(ITestGeneric)} via as"); + var ex = Assert.Throws(() => { var _ = (ITestGeneric)castableObj; }); + Assert.AreEqual(string.Format(DynamicInterfaceCastableException.ErrorFormat, typeof(ITestGeneric)), ex.Message); int expectedInt = 42; string expectedStr = "str"; @@ -330,12 +344,15 @@ private static void ValidateGenericInterface() Console.WriteLine(" -- Validate method call"); Assert.AreEqual(expectedInt, testInt.ReturnArg(42)); Assert.AreEqual(expectedStr, testStr.ReturnArg(expectedStr)); + Assert.AreEqual(expectedStr, testVar.ReturnArg(expectedStr)); Console.WriteLine(" -- Validate delegate call"); Func funcInt = new Func(testInt.ReturnArg); Assert.AreEqual(expectedInt, funcInt(expectedInt)); Func funcStr = new Func(testStr.ReturnArg); Assert.AreEqual(expectedStr, funcStr(expectedStr)); + Func funcVar = new Func(testVar.ReturnArg); + Assert.AreEqual(expectedStr, funcVar(expectedStr)); } private static void ValidateOverriddenInterface()