diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 4d71dc9d788a99..116ee786fb86cd 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -150,11 +150,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) GenerateImplementationInterface(x, ct).NormalizeWhitespace(), GenerateInterfaceImplementationVtable(x, ct).NormalizeWhitespace(), GenerateImplementationVTableMethods(x, ct).NormalizeWhitespace(), - x.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier(TypeDeclaration(x.Interface.Info.ContainingSyntax.TypeKind, x.Interface.Info.ContainingSyntax.Identifier) - .WithModifiers(x.Interface.Info.ContainingSyntax.Modifiers) - .WithTypeParameterList(x.Interface.Info.ContainingSyntax.TypeParameters) - .WithMembers(List(x.ShadowingMethods.Select(m => m.Shadow)))) - .NormalizeWhitespace(), + GenerateShadowingMethodsInterface(x), GenerateImplementationVTable(x, ct).NormalizeWhitespace(), GenerateInterfaceInformation(x.Interface.Info, ct).NormalizeWhitespace(), GenerateIUnknownDerivedAttributeApplication(x.Interface.Info, ct).NormalizeWhitespace() @@ -231,6 +227,17 @@ private static MemberDeclarationSyntax GenerateIUnknownDerivedAttributeApplicati .WithTypeParameterList(context.ContainingSyntax.TypeParameters) .AddAttributeLists(AttributeList(SingletonSeparatedList(s_iUnknownDerivedAttributeTemplate)))); + private static MemberDeclarationSyntax GenerateShadowingMethodsInterface(ComInterfaceAndMethodsContext interfaceMethods) + { + var containingSyntax = interfaceMethods.Interface.Info.ContainingSyntax; + return interfaceMethods.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier( + TypeDeclaration(containingSyntax.TypeKind, containingSyntax.Identifier) + .WithModifiers(containingSyntax.Modifiers.AddToModifiers(SyntaxKind.UnsafeKeyword)) + .WithTypeParameterList(containingSyntax.TypeParameters) + .WithMembers(List(interfaceMethods.ShadowingMethods.Select(m => m.Shadow)))) + .NormalizeWhitespace(); + } + private static bool IsHResultLikeType(ManagedTypeInfo type) { string typeName = type.FullTypeName.Split('.', ':')[^1]; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs index 514b400c3f1c91..ad68b2dfb5761e 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs @@ -701,6 +701,25 @@ unsafe partial interface INativeDerived : INativeAPIBase } """; + public string DerivedComInterfaceTypeWithUnsafeBaseMethod => $$""" + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [assembly:DisableRuntimeMarshalling] + + {{GeneratedComInterface()}} + partial interface IComInterfaceBase + { + unsafe void Method(void* pBuffer); + } + {{GeneratedComInterface()}} + partial interface IComInterfaceDerived : IComInterfaceBase + { + void Method2(); + } + """; + public class ManagedToUnmanaged : IVirtualMethodIndexSignatureProvider { public MarshalDirection Direction => MarshalDirection.ManagedToUnmanaged; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs index b255a2a29a44df..c0806b2853d8f5 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs @@ -350,6 +350,7 @@ public static IEnumerable ComInterfaceSnippetsToCompile() yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("in") }; yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("out") }; yield return new object[] { ID(), codeSnippets.ComInterfaceWithNativeMarshalling }; + yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeWithUnsafeBaseMethod }; } public static IEnumerable ManagedToUnmanagedComInterfaceSnippetsToCompile()