From e9b37191b5c0c9c9bf636bcfac3652bb66cd4439 Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Fri, 13 Mar 2026 19:54:48 +0300 Subject: [PATCH 1/4] Port `ComClassGenerator` to string writing instead of constructing syntax nodes --- .../ComClassGenerator.cs | 213 +++++------------- .../gen/ComInterfaceGenerator/ComClassInfo.cs | 8 +- .../ContainingSyntaxContext.cs | 39 ++++ .../TypeNames.cs | 28 --- 4 files changed, 98 insertions(+), 190 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs index 4f8940ca0df19f..10260378e35f5f 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs @@ -1,8 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections.Generic; -using System.Collections.Immutable; +using System.CodeDom.Compiler; using System.IO; using System.Linq; using Microsoft.CodeAnalysis; @@ -10,18 +9,18 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; using Microsoft.Interop.Analyzers; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using static Microsoft.Interop.SyntaxFactoryExtensions; namespace Microsoft.Interop { [Generator] public class ComClassGenerator : IIncrementalGenerator { + private const string ClassInfoTypeName = "ComClassInformation"; + public void Initialize(IncrementalGeneratorInitializationContext context) { // Get all types with the [GeneratedComClassAttribute] attribute. - var attributedClasses = context.SyntaxProvider + IncrementalValuesProvider attributedClasses = context.SyntaxProvider .ForAttributeWithMetadataName( TypeNames.GeneratedComClassAttribute, static (node, ct) => node is ClassDeclarationSyntax, @@ -29,8 +28,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var type = (INamedTypeSymbol)context.TargetSymbol; var syntax = (ClassDeclarationSyntax)context.TargetNode; - var compilation = context.SemanticModel.Compilation; - var unsafeCodeIsEnabled = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true }; + Compilation compilation = context.SemanticModel.Compilation; + bool unsafeCodeIsEnabled = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true }; INamedTypeSymbol? generatedComInterfaceAttributeType = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute); // Currently all reported diagnostics are fatal to the generator @@ -43,169 +42,61 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }) .Where(static info => info is not null); - var classInfoType = attributedClasses - .Select(static (info, ct) => new ItemAndSyntaxes(info, - [ - GenerateClassInfoType(info.ImplementedInterfacesNames.Array).NormalizeWhitespace(), - GenerateClassInfoAttributeOnUserType(info.ContainingSyntaxContext, info.ClassSyntax).NormalizeWhitespace() - ])); - - context.RegisterSourceOutput(classInfoType, static (context, data) => + context.RegisterSourceOutput(attributedClasses, (context, data) => { - var className = data.Context.ClassName; - var classInfoType = data[0]; - var attribute = data[1]; + string className = data.ClassName; + SequenceEqualImmutableArray implementedInterfaces = data.ImplementedInterfacesNames; - StringWriter writer = new(); + using StringWriter sw = new(); + using IndentedTextWriter writer = new(sw); writer.WriteLine("// "); - writer.WriteLine(classInfoType.ToFullString()); - writer.WriteLine(); - writer.WriteLine(attribute); - // Replace < and > with { and } to make valid hint names for generic types - string hintName = className.Replace('<', '{').Replace('>', '}'); - context.AddSource(hintName, writer.ToString()); - }); - } - - private const string ClassInfoTypeName = "ComClassInformation"; - - private static readonly AttributeSyntax s_comExposedClassAttributeTemplate = - Attribute( - GenericName(TypeNames.GlobalAlias + TypeNames.ComExposedClassAttribute) - .AddTypeArgumentListArguments( - IdentifierName(ClassInfoTypeName))); - private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax) => - containingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier( - TypeDeclaration(classSyntax.TypeKind, classSyntax.Identifier) - .WithModifiers(classSyntax.Modifiers) - .WithTypeParameterList(classSyntax.TypeParameters) - .AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate)))); - private static ClassDeclarationSyntax GenerateClassInfoType(ImmutableArray implementedInterfaces) - { - const string vtablesField = "s_vtables"; - const string vtablesLocal = "vtables"; - const string detailsTempLocal = "details"; - const string countIdentifier = "count"; - var typeDeclaration = ClassDeclaration(ClassInfoTypeName) - .AddModifiers( - Token(SyntaxKind.FileKeyword), - Token(SyntaxKind.SealedKeyword), - Token(SyntaxKind.UnsafeKeyword)) - .AddBaseListTypes(SimpleBaseType(TypeSyntaxes.IComExposedClass)) - .AddMembers( - FieldDeclaration( - VariableDeclaration( - PointerType(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry), - SingletonSeparatedList(VariableDeclarator(vtablesField)))) - .AddModifiers( - Token(SyntaxKind.PrivateKeyword), - Token(SyntaxKind.StaticKeyword), - Token(SyntaxKind.VolatileKeyword))); - List vtableInitializationBlock = new() - { - // ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(), sizeof(ComInterfaceEntry) * ); - Declare( - PointerType(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry), - vtablesLocal, - CastExpression( - PointerType(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry), - MethodInvocation( - TypeSyntaxes.System_Runtime_CompilerServices_RuntimeHelpers, - IdentifierName("AllocateTypeAssociatedMemory"), - Argument(TypeOfExpression(IdentifierName(ClassInfoTypeName))), - Argument( - BinaryExpression( - SyntaxKind.MultiplyExpression, - SizeOfExpression(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry), - LiteralExpression( - SyntaxKind.NumericLiteralExpression, - Literal(implementedInterfaces.Length))))))), - // IIUnknownDerivedDetails details; - Declare(TypeSyntaxes.IIUnknownDerivedDetails, detailsTempLocal, initializeToDefault: false) - }; - for (int i = 0; i < implementedInterfaces.Length; i++) - { - string ifaceName = implementedInterfaces[i]; + writer.WriteLine($"file sealed unsafe class {ClassInfoTypeName} : global::System.Runtime.InteropServices.Marshalling.IComExposedClass"); + writer.WriteLine('{'); + writer.Indent++; + writer.WriteLine("private static volatile global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables;"); + sw.WriteLine(); + writer.WriteLine("public static global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count)"); + writer.WriteLine('{'); + writer.Indent++; + writer.WriteLine($"count = {implementedInterfaces.Length};"); + writer.WriteLine("if (s_vtables == null)"); + writer.WriteLine('{'); + writer.Indent++; + writer.WriteLine($"global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({ClassInfoTypeName}), sizeof(global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * {implementedInterfaces.Length});"); + writer.WriteLine("global::System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details;"); + sw.WriteLine(); + for (int i = 0; i < implementedInterfaces.Length; i++) + { + string ifaceName = implementedInterfaces[i]; - // details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof().TypeHandle); - vtableInitializationBlock.Add( - AssignmentStatement( - IdentifierName(detailsTempLocal), - MethodInvocation( - TypeSyntaxes.StrategyBasedComWrappers - .Dot(IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")), - IdentifierName("GetIUnknownDerivedDetails"), - Argument( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - TypeOfExpression(ParseName(ifaceName)), - IdentifierName("TypeHandle")))))); - // vtable[i] = new() { IID = details.Iid, Vtable = details.ManagedVirtualMethodTable }; - vtableInitializationBlock.Add( - AssignmentStatement( - IndexExpression( - IdentifierName(vtablesLocal), - Argument(IntLiteral(i))), - ImplicitObjectCreationExpression( - ArgumentList(), - InitializerExpression(SyntaxKind.ObjectInitializerExpression, - SeparatedList( - new ExpressionSyntax[] - { - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName("IID"), - IdentifierName(detailsTempLocal) - .Dot(IdentifierName("Iid"))), - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName("Vtable"), - CastExpression( - IdentifierName("nint"), - IdentifierName(detailsTempLocal) - .Dot(IdentifierName("ManagedVirtualMethodTable")))) - }))))); - } + writer.WriteLine($"details = global::System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof({ifaceName}).TypeHandle);"); + writer.WriteLine($"vtables[{i}] = new() {{ IID = details.Iid, Vtable = (nint)details.ManagedVirtualMethodTable }};"); + sw.WriteLine(); + } - // s_vtable = vtable; - vtableInitializationBlock.Add( - ExpressionStatement( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(vtablesField), - IdentifierName(vtablesLocal)))); + writer.WriteLine("s_vtables = vtables;"); + writer.Indent--; + writer.WriteLine('}'); + sw.WriteLine(); + writer.WriteLine("return s_vtables;"); + writer.Indent--; + writer.WriteLine('}'); + writer.Indent--; + writer.WriteLine('}'); - BlockSyntax getComInterfaceEntriesMethodBody = Block( - // count = ; - ExpressionStatement( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(countIdentifier), - LiteralExpression(SyntaxKind.NumericLiteralExpression, - Literal(implementedInterfaces.Length)))), - // if (s_vtable == null) - // { initializer block } - IfStatement( - BinaryExpression(SyntaxKind.EqualsExpression, - IdentifierName(vtablesField), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - Block(vtableInitializationBlock)), - // return s_vtable; - ReturnStatement(IdentifierName(vtablesField))); + sw.WriteLine(); - typeDeclaration = typeDeclaration.AddMembers( - // public static unsafe ComWrappers.ComInterfaceDispatch* GetComInterfaceEntries(out int count) - // { body } - MethodDeclaration( - PointerType( - TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry), - "GetComInterfaceEntries") - .AddParameterListParameters( - Parameter(Identifier(countIdentifier)) - .WithType(PredefinedType(Token(SyntaxKind.IntKeyword))) - .AddModifiers(Token(SyntaxKind.OutKeyword))) - .WithBody(getComInterfaceEntriesMethodBody) - .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword))); + data.ContainingSyntaxContext.WriteToWithUnsafeModifier(writer, data.ClassSyntax, static (writer, classSyntax) => + { + writer.WriteLine($"[global::System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute<{ClassInfoTypeName}>]"); + writer.WriteLine($"{string.Join(" ", classSyntax.Modifiers)} class {classSyntax.Identifier}{classSyntax.TypeParameters} {{ }}"); + }); - return typeDeclaration; + // Replace < and > with { and } to make valid hint names for generic types + string hintName = className.Replace('<', '{').Replace('>', '}'); + context.AddSource(hintName, sw.ToString()); + }); } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs index 838c23c01644bc..58a3caa495ea58 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Immutable; using System.Linq; using Microsoft.CodeAnalysis; @@ -8,7 +9,7 @@ namespace Microsoft.Interop { - internal sealed record ComClassInfo + internal sealed class ComClassInfo : IEquatable { public string ClassName { get; init; } public ContainingSyntaxContext ContainingSyntaxContext { get; init; } @@ -54,6 +55,11 @@ public bool Equals(ComClassInfo? other) && ImplementedInterfacesNames.SequenceEqual(other.ImplementedInterfacesNames); } + public override bool Equals(object obj) + { + return Equals(obj as ComClassInfo); + } + public override int GetHashCode() { return HashCode.Combine(ClassName, ContainingSyntaxContext, ImplementedInterfacesNames); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs index 1ac5effeaced95..308ef9aeb54a02 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.CodeDom.Compiler; using System.Collections.Immutable; using System.Linq; using System.Text; @@ -149,5 +150,43 @@ public MemberDeclarationSyntax WrapMembersInContainingSyntaxWithUnsafeModifier(p } return wrappedMember; } + + public void WriteToWithUnsafeModifier(IndentedTextWriter writer, TState writeMembersState, Action writeMembers) + { + if (ContainingNamespace is not null) + { + writer.WriteLine($"namespace {ContainingNamespace}"); + writer.WriteLine('{'); + writer.Indent++; + } + + // When creating syntax we walk from most nested type to least nested and then enclose this chain in a namespace. + // With string writing things are exactly opposite: we are starting with a namespace and then print headers of types + // from least nested to most nested one. Since syntax model was the original one we have containing syntaxes stored as + // most convenient for it, so for string writing we have to walk them in the reverse order. When we eventually port + // our source generation to string writing we should reverse the order of elements for the convenience of that model instead. + for (int i = ContainingSyntax.Length - 1; i >= 0; i--) + { + ContainingSyntax syntax = ContainingSyntax[i]; + + writer.WriteLine($"{string.Join(" ", syntax.Modifiers.AddToModifiers(SyntaxKind.UnsafeKeyword))} {SyntaxFacts.GetText(syntax.TypeKind)} {syntax.Identifier}{syntax.TypeParameters}"); + writer.WriteLine('{'); + writer.Indent++; + } + + writeMembers(writer, writeMembersState); + + for (int i = 0; i < ContainingSyntax.Length; i++) + { + writer.Indent--; + writer.WriteLine('}'); + } + + if (ContainingNamespace is not null) + { + writer.Indent--; + writer.WriteLine('}'); + } + } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs index 435a1c3268cdba..f906e506bdec31 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs @@ -13,15 +13,9 @@ public static class NameSyntaxes private static NameSyntax? _DllImportAttribute; public static NameSyntax DllImportAttribute => _DllImportAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.DllImportAttribute); - private static NameSyntax? _LibraryImportAttribute; - public static NameSyntax LibraryImportAttribute => _LibraryImportAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.LibraryImportAttribute); - private static NameSyntax? _System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute; public static NameSyntax System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute => _System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute); - private static NameSyntax? _System_Runtime_InteropServices_MarshalAsAttribute; - public static NameSyntax System_Runtime_InteropServices_MarshalAsAttribute => _System_Runtime_InteropServices_MarshalAsAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.System_Runtime_InteropServices_MarshalAsAttribute); - private static NameSyntax? _DefaultDllImportSearchPathsAttribute; public static NameSyntax DefaultDllImportSearchPathsAttribute => _DefaultDllImportSearchPathsAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.DefaultDllImportSearchPathsAttribute); @@ -54,20 +48,10 @@ public static class NameSyntaxes public static class TypeSyntaxes { - public static TypeSyntax Void { get; } = PredefinedType(Token(SyntaxKind.VoidKeyword)); - public static TypeSyntax VoidStar { get; } = PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))); public static TypeSyntax VoidStarStar { get; } = PointerType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))); - public static TypeSyntax Nint { get; } = ParseTypeName(TypeNames.Nint); - - private static TypeSyntax? _StringMarshalling; - public static TypeSyntax StringMarshalling => _StringMarshalling ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.StringMarshalling); - - private static TypeSyntax? _System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry; - public static TypeSyntax System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry => _System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry); - private static TypeSyntax? _System_Runtime_InteropServices_NativeMemory; public static TypeSyntax System_Runtime_InteropServices_NativeMemory => _System_Runtime_InteropServices_NativeMemory ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.System_Runtime_InteropServices_NativeMemory); @@ -80,21 +64,12 @@ public static class TypeSyntaxes private static TypeSyntax? _IIUnknownInterfaceType; public static TypeSyntax IIUnknownInterfaceType => _IIUnknownInterfaceType ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.IIUnknownInterfaceType); - private static TypeSyntax? _IIUnknownDerivedDetails; - public static TypeSyntax IIUnknownDerivedDetails => _IIUnknownDerivedDetails ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.IIUnknownDerivedDetails); - private static TypeSyntax? _UnmanagedObjectUnwrapper; public static TypeSyntax UnmanagedObjectUnwrapper => _UnmanagedObjectUnwrapper ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.UnmanagedObjectUnwrapper); - private static TypeSyntax? _IComExposedClass; - public static TypeSyntax IComExposedClass => _IComExposedClass ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.IComExposedClass); - private static TypeSyntax? _UnreachableException; public static TypeSyntax UnreachableException => _UnreachableException ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.UnreachableException); - private static TypeSyntax? _System_Runtime_CompilerServices_RuntimeHelpers; - public static TypeSyntax System_Runtime_CompilerServices_RuntimeHelpers => _System_Runtime_CompilerServices_RuntimeHelpers ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.System_Runtime_CompilerServices_RuntimeHelpers); - private static TypeSyntax? _System_Runtime_InteropServices_ComWrappers; public static TypeSyntax System_Runtime_InteropServices_ComWrappers => _System_Runtime_InteropServices_ComWrappers ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.System_Runtime_InteropServices_ComWrappers); @@ -110,9 +85,6 @@ public static class TypeSyntaxes private static TypeSyntax? _System_Type; public static TypeSyntax System_Type => _System_Type ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.System_Type); - private static TypeSyntax? _System_Activator; - public static TypeSyntax System_Activator => _System_Activator ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.System_Activator); - private static TypeSyntax? _System_Runtime_InteropServices_Marshal; public static TypeSyntax System_Runtime_InteropServices_Marshal => _System_Runtime_InteropServices_Marshal ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.System_Runtime_InteropServices_Marshal); From 3d5541ccd5557a4f73f62f85d0050b6e3c42665b Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Fri, 13 Mar 2026 19:57:50 +0300 Subject: [PATCH 2/4] Make lambda static --- .../gen/ComInterfaceGenerator/ComClassGenerator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs index 10260378e35f5f..fb000dc2bc74a6 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs @@ -42,7 +42,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }) .Where(static info => info is not null); - context.RegisterSourceOutput(attributedClasses, (context, data) => + context.RegisterSourceOutput(attributedClasses, static (context, data) => { string className = data.ClassName; SequenceEqualImmutableArray implementedInterfaces = data.ImplementedInterfacesNames; From 48884b5ee572b8284736d9d67979539afc8e49d0 Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Fri, 13 Mar 2026 22:28:44 +0300 Subject: [PATCH 3/4] Fix --- .../ContainingSyntaxContext.cs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs index 308ef9aeb54a02..df73ba5b13cc25 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs @@ -169,7 +169,18 @@ public void WriteToWithUnsafeModifier(IndentedTextWriter writer, TState { ContainingSyntax syntax = ContainingSyntax[i]; - writer.WriteLine($"{string.Join(" ", syntax.Modifiers.AddToModifiers(SyntaxKind.UnsafeKeyword))} {SyntaxFacts.GetText(syntax.TypeKind)} {syntax.Identifier}{syntax.TypeParameters}"); + string declarationKeyword = syntax.TypeKind switch + { + SyntaxKind.ClassDeclaration => "class", + SyntaxKind.StructDeclaration => "struct", + SyntaxKind.InterfaceDeclaration => "interface", + SyntaxKind.EnumDeclaration => "enum", + SyntaxKind.RecordDeclaration => "record", + SyntaxKind.RecordStructDeclaration => "record struct", + _ => throw new UnreachableException(), + }; + + writer.WriteLine($"{string.Join(" ", syntax.Modifiers.AddToModifiers(SyntaxKind.UnsafeKeyword))} {declarationKeyword} {syntax.Identifier}{syntax.TypeParameters}"); writer.WriteLine('{'); writer.Indent++; } From 6cabfe6444441b4b54468d3500bb8b73e13a0002 Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Fri, 13 Mar 2026 23:54:09 +0300 Subject: [PATCH 4/4] Add a unit test for a nested COM class --- .../ContainingSyntaxContext.cs | 1 - .../ComClassGeneratorOutputShape.cs | 28 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs index df73ba5b13cc25..832be0b79db305 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ContainingSyntaxContext.cs @@ -174,7 +174,6 @@ public void WriteToWithUnsafeModifier(IndentedTextWriter writer, TState SyntaxKind.ClassDeclaration => "class", SyntaxKind.StructDeclaration => "struct", SyntaxKind.InterfaceDeclaration => "interface", - SyntaxKind.EnumDeclaration => "enum", SyntaxKind.RecordDeclaration => "record", SyntaxKind.RecordStructDeclaration => "record struct", _ => throw new UnreachableException(), diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs index 1b4dcc94807397..6a924980bdd939 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs @@ -89,6 +89,34 @@ partial interface INativeAPI await VerifySourceGeneratorAsync(source, "GenericClass`1"); } + [Theory] + [InlineData("class")] + [InlineData("struct")] + [InlineData("interface")] + [InlineData("record")] + [InlineData("record class")] + [InlineData("record struct")] + public async Task NestedComClass(string containingTypeKeyword) + { + string source = $$""" + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + partial interface INativeAPI + { + } + + partial {{containingTypeKeyword}} ContainingType + { + [GeneratedComClass] + partial class C : INativeAPI {} + } + """; + + await VerifySourceGeneratorAsync(source, "ContainingType+C"); + } + private static async Task VerifySourceGeneratorAsync(string source, params string[] typeNames) { GeneratedShapeTest test = new(typeNames)