From 6e9bbaaf6f9fdf9ffa9a814b4a32ca09095dcf94 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Feb 2026 19:56:04 +0000 Subject: [PATCH 1/4] Initial plan From 46c3a3ea14d51cbc2cc3953d721035ccc9cd0555 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Feb 2026 20:40:54 +0000 Subject: [PATCH 2/4] Move diagnostics from ComInterfaceGenerator and VtableIndexStubGenerator into separate analyzers Co-authored-by: jkoritzinsky <1571408+jkoritzinsky@users.noreply.github.com> --- ...omInterfaceGeneratorDiagnosticsAnalyzer.cs | 241 ++++++++++++++++++ .../VtableIndexStubDiagnosticsAnalyzer.cs | 143 +++++++++++ .../ComInterfaceGenerator.cs | 9 +- .../VtableIndexStubGenerator.cs | 59 +---- .../ByValueContentsMarshalling.cs | 2 +- .../CompileFails.cs | 2 +- 6 files changed, 391 insertions(+), 65 deletions(-) create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/VtableIndexStubDiagnosticsAnalyzer.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs new file mode 100644 index 00000000000000..2396ad2c1613c0 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs @@ -0,0 +1,241 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; + +namespace Microsoft.Interop.Analyzers +{ + [DiagnosticAnalyzer(LanguageNames.CSharp)] + public class ComInterfaceGeneratorDiagnosticsAnalyzer : DiagnosticAnalyzer + { + public override ImmutableArray SupportedDiagnostics { get; } = + ImmutableArray.Create( + // Interface-level diagnostics + GeneratorDiagnostics.RequiresAllowUnsafeBlocks, + GeneratorDiagnostics.InvalidAttributedInterfaceGenericNotSupported, + GeneratorDiagnostics.InvalidAttributedInterfaceMissingPartialModifiers, + GeneratorDiagnostics.InvalidAttributedInterfaceNotAccessible, + GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute, + GeneratorDiagnostics.InvalidStringMarshallingMismatchBetweenBaseAndDerived, + GeneratorDiagnostics.InvalidOptionsOnInterface, + GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnInterface, + GeneratorDiagnostics.InvalidExceptionToUnmanagedMarshallerType, + GeneratorDiagnostics.StringMarshallingCustomTypeNotAccessibleByGeneratedCode, + GeneratorDiagnostics.ExceptionToUnmanagedMarshallerNotAccessibleByGeneratedCode, + GeneratorDiagnostics.MultipleComInterfaceBaseTypes, + GeneratorDiagnostics.BaseInterfaceIsNotGenerated, + GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, + // Method-level diagnostics + GeneratorDiagnostics.MethodNotDeclaredInAttributedInterface, + GeneratorDiagnostics.InstancePropertyDeclaredInInterface, + GeneratorDiagnostics.InstanceEventDeclaredInInterface, + GeneratorDiagnostics.CannotAnalyzeMethodPattern, + GeneratorDiagnostics.CannotAnalyzeInterfacePattern, + // Stub-level diagnostics + GeneratorDiagnostics.ConfigurationNotSupported, + GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnMethod, + GeneratorDiagnostics.ParameterTypeNotSupported, + GeneratorDiagnostics.ReturnTypeNotSupported, + GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails, + GeneratorDiagnostics.ReturnTypeNotSupportedWithDetails, + GeneratorDiagnostics.ParameterConfigurationNotSupported, + GeneratorDiagnostics.ReturnConfigurationNotSupported, + GeneratorDiagnostics.MarshalAsParameterConfigurationNotSupported, + GeneratorDiagnostics.MarshalAsReturnConfigurationNotSupported, + GeneratorDiagnostics.ConfigurationValueNotSupported, + GeneratorDiagnostics.MarshallingAttributeConfigurationNotSupported, + GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo, + GeneratorDiagnostics.UnnecessaryReturnMarshallingInfo, + GeneratorDiagnostics.ComMethodManagedReturnWillBeOutVariable, + GeneratorDiagnostics.HResultTypeWillBeTreatedAsStruct, + GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallOutParam, + GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallReturnValue, + GeneratorDiagnostics.InvalidExceptionMarshallingConfiguration, + GeneratorDiagnostics.GeneratedComInterfaceUsageDoesNotFollowBestPractices); + + public override void Initialize(AnalysisContext context) + { + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + context.EnableConcurrentExecution(); + context.RegisterCompilationStartAction(compilationContext => + { + INamedTypeSymbol? generatedComInterfaceAttrType = compilationContext.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute); + if (generatedComInterfaceAttrType is null) + return; + + StubEnvironment env = new StubEnvironment( + compilationContext.Compilation, + compilationContext.Compilation.GetEnvironmentFlags()); + + var interfaceSymbols = new ConcurrentBag<(InterfaceDeclarationSyntax Syntax, INamedTypeSymbol Symbol)>(); + + compilationContext.RegisterSymbolAction(symbolContext => + { + INamedTypeSymbol typeSymbol = (INamedTypeSymbol)symbolContext.Symbol; + if (typeSymbol.TypeKind != TypeKind.Interface) + return; + + foreach (AttributeData attr in typeSymbol.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttrType)) + { + // Use the syntax reference that contains the attribute application + // (important for partial types where the attribute is on a specific partial declaration) + var attrSyntaxRef = attr.ApplicationSyntaxReference; + if (attrSyntaxRef is not null) + { + var attrSyntax = attrSyntaxRef.GetSyntax(symbolContext.CancellationToken); + if (attrSyntax.Parent?.Parent is InterfaceDeclarationSyntax ifaceSyntax) + { + interfaceSymbols.Add((ifaceSyntax, typeSymbol)); + } + } + else + { + foreach (SyntaxReference syntaxRef in typeSymbol.DeclaringSyntaxReferences) + { + if (syntaxRef.GetSyntax(symbolContext.CancellationToken) is InterfaceDeclarationSyntax ifaceSyntax) + { + interfaceSymbols.Add((ifaceSyntax, typeSymbol)); + break; + } + } + } + break; + } + } + }, SymbolKind.NamedType); + + compilationContext.RegisterCompilationEndAction(endContext => + { + if (interfaceSymbols.IsEmpty) + return; + + AnalyzeInterfaces(endContext, env, interfaceSymbols.ToImmutableArray()); + }); + }); + } + + private static void AnalyzeInterfaces( + CompilationAnalysisContext context, + StubEnvironment env, + ImmutableArray<(InterfaceDeclarationSyntax Syntax, INamedTypeSymbol Symbol)> attributedInterfaces) + { + CancellationToken ct = context.CancellationToken; + + // This mirrors the analysis phase of ComInterfaceGenerator.Initialize. + List<(ComInterfaceInfo, INamedTypeSymbol)> interfaceInfos = new(); + HashSet<(ComInterfaceInfo, INamedTypeSymbol)> externalIfaces = new(ComInterfaceInfo.EqualityComparerForExternalIfaces.Instance); + List diags = new(); + + foreach (var (syntax, symbol) in attributedInterfaces) + { + var cii = ComInterfaceInfo.From(symbol, syntax, env, ct); + if (cii.HasDiagnostic) + { + foreach (var diag in cii.Diagnostics) + diags.Add(diag); + } + if (cii.HasValue) + interfaceInfos.Add(cii.Value); + + var externalBase = ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(symbol); + if (!externalBase.IsDefaultOrEmpty) + { + foreach (var b in externalBase) + externalIfaces.Add(b); + } + } + + interfaceInfos.AddRange(externalIfaces); + + var comInterfaceContexts = ComInterfaceContext.GetContexts(interfaceInfos.Select(i => i.Item1).ToImmutableArray(), ct); + + Dictionary methodSymbols = new(); + List> methods = new(); + + foreach (var cii in interfaceInfos) + { + var cmi = ComMethodInfo.GetMethodsFromInterface(cii, ct); + var inner = new List(); + foreach (var m in cmi) + { + if (m.HasDiagnostic) + { + foreach (var diag in m.Diagnostics) + diags.Add(diag); + } + if (m.HasValue) + { + inner.Add(m.Value.ComMethod); + methodSymbols.Add(m.Value.ComMethod, m.Value.Symbol); + } + } + methods.Add(inner); + } + + List<(ComInterfaceContext, SequenceEqualImmutableArray)> ifaceCtxs = new(); + for (int i = 0; i < interfaceInfos.Count; i++) + { + var cic = comInterfaceContexts[i]; + if (cic.HasDiagnostic) + { + foreach (var diag in cic.Diagnostics) + diags.Add(diag); + } + if (cic.HasValue) + { + ifaceCtxs.Add((cic.Value, methods[i].ToSequenceEqualImmutableArray())); + } + } + + // Report interface-level and method-level diagnostics + foreach (var diag in diags) + context.ReportDiagnostic(diag.ToDiagnostic()); + + var result = ComMethodContext.CalculateAllMethods(ifaceCtxs, ct); + + List methodContexts = new(); + foreach (var data in result) + { + methodContexts.Add(new ComMethodContext( + data.Method, + data.OwningInterface, + ComInterfaceGenerator.CalculateStubInformation( + data.Method.MethodInfo.Syntax, + methodSymbols[data.Method.MethodInfo], + data.Method.Index, + env, + data.OwningInterface.Info, + ct))); + } + + // Group method contexts by owning interface to match the generator's GroupComContextsForInterfaceGeneration + // and only report diagnostics for declared (non-inherited) methods. + var groupedByOwningInterface = methodContexts + .GroupBy(m => m.OwningInterface); + + foreach (var group in groupedByOwningInterface) + { + var declaredMethods = group.Where(static m => !m.IsInheritedMethod).ToList(); + + // Report diagnostics for managed-to-unmanaged and unmanaged-to-managed stubs, + // deduplicating diagnostics that are reported for both (matching the generator behavior). + var allStubDiags = declaredMethods + .SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics.Array) + .Union(declaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics.Array)); + + foreach (var diag in allStubDiags) + context.ReportDiagnostic(diag.ToDiagnostic()); + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/VtableIndexStubDiagnosticsAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/VtableIndexStubDiagnosticsAnalyzer.cs new file mode 100644 index 00000000000000..01af871b46328e --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/VtableIndexStubDiagnosticsAnalyzer.cs @@ -0,0 +1,143 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; + +namespace Microsoft.Interop.Analyzers +{ + [DiagnosticAnalyzer(LanguageNames.CSharp)] + public class VtableIndexStubDiagnosticsAnalyzer : DiagnosticAnalyzer + { + public override ImmutableArray SupportedDiagnostics { get; } = + ImmutableArray.Create( + GeneratorDiagnostics.InvalidAttributedMethodSignature, + GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, + GeneratorDiagnostics.ReturnConfigurationNotSupported, + GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingUnmanagedObjectUnwrapperAttribute, + GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnMethod, + GeneratorDiagnostics.InvalidExceptionMarshallingConfiguration, + GeneratorDiagnostics.ConfigurationNotSupported, + GeneratorDiagnostics.ParameterTypeNotSupported, + GeneratorDiagnostics.ReturnTypeNotSupported, + GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails, + GeneratorDiagnostics.ReturnTypeNotSupportedWithDetails, + GeneratorDiagnostics.ParameterConfigurationNotSupported, + GeneratorDiagnostics.MarshalAsParameterConfigurationNotSupported, + GeneratorDiagnostics.MarshalAsReturnConfigurationNotSupported, + GeneratorDiagnostics.ConfigurationValueNotSupported, + GeneratorDiagnostics.MarshallingAttributeConfigurationNotSupported, + GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo, + GeneratorDiagnostics.UnnecessaryReturnMarshallingInfo, + GeneratorDiagnostics.GeneratedComInterfaceUsageDoesNotFollowBestPractices); + + public override void Initialize(AnalysisContext context) + { + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + context.EnableConcurrentExecution(); + context.RegisterCompilationStartAction(compilationContext => + { + INamedTypeSymbol? virtualMethodIndexAttrType = compilationContext.Compilation.GetBestTypeByMetadataName(TypeNames.VirtualMethodIndexAttribute); + if (virtualMethodIndexAttrType is null) + return; + + StubEnvironment env = new StubEnvironment( + compilationContext.Compilation, + compilationContext.Compilation.GetEnvironmentFlags()); + + compilationContext.RegisterSymbolAction(symbolContext => + { + IMethodSymbol method = (IMethodSymbol)symbolContext.Symbol; + AttributeData? virtualMethodIndexAttr = null; + foreach (AttributeData attr in method.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, virtualMethodIndexAttrType)) + { + virtualMethodIndexAttr = attr; + break; + } + } + + if (virtualMethodIndexAttr is null) + return; + + foreach (SyntaxReference syntaxRef in method.DeclaringSyntaxReferences) + { + if (syntaxRef.GetSyntax(symbolContext.CancellationToken) is MethodDeclarationSyntax methodSyntax) + { + AnalyzeMethod(symbolContext, methodSyntax, method, env); + break; + } + } + }, SymbolKind.Method); + }); + } + + private static void AnalyzeMethod(SymbolAnalysisContext context, MethodDeclarationSyntax methodSyntax, IMethodSymbol method, StubEnvironment env) + { + DiagnosticInfo? invalidMethodDiagnostic = GetDiagnosticIfInvalidMethodForGeneration(methodSyntax, method); + if (invalidMethodDiagnostic is not null) + { + context.ReportDiagnostic(invalidMethodDiagnostic.ToDiagnostic()); + return; + } + + SourceAvailableIncrementalMethodStubGenerationContext stubContext = VtableIndexStubGenerator.CalculateStubInformation(methodSyntax, method, env, context.CancellationToken); + + if (stubContext.VtableIndexData.Direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) + { + var (_, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(stubContext, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); + foreach (DiagnosticInfo diag in diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); + } + + if (stubContext.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) + { + var (_, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(stubContext, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); + foreach (DiagnosticInfo diag in diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); + } + } + + internal static DiagnosticInfo? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method) + { + // Verify the method has no generic types or defined implementation + // and is not marked static or sealed + if (methodSyntax.TypeParameterList is not null + || methodSyntax.Body is not null + || methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword) + || methodSyntax.Modifiers.Any(SyntaxKind.SealedKeyword)) + { + return DiagnosticInfo.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, methodSyntax.Identifier.GetLocation(), method.Name); + } + + // Verify that the types the method is declared in are marked partial. + for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) + { + if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) + { + return DiagnosticInfo.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, methodSyntax.Identifier.GetLocation(), method.Name, typeDecl.Identifier); + } + } + + // Verify the method does not have a ref return + if (method.ReturnsByRef || method.ReturnsByRefReadonly) + { + return DiagnosticInfo.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, methodSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString()); + } + + // Verify there is an [UnmanagedObjectUnwrapperAttribute] + if (!method.ContainingType.GetAttributes().Any(att => att.AttributeClass.IsOfType(TypeNames.UnmanagedObjectUnwrapperAttribute))) + { + return DiagnosticInfo.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingUnmanagedObjectUnwrapperAttribute, methodSyntax.Identifier.GetLocation(), method.Name); + } + + return null; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index e10a2ddd2f67dc..45667149844ad1 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -157,8 +157,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) ); }); - context.RegisterDiagnostics(attributedInterfaces.SelectMany(static (data, ct) => data.Diagnostics)); - // Create list of methods (inherited and declared) and their owning interface var interfaceContextsToGenerate = attributedInterfaces.SelectMany(static (a, ct) => a.InterfaceContexts); var comMethodContexts = attributedInterfaces.Select(static (a, ct) => a.MethodContexts); @@ -185,11 +183,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) GenerateIUnknownDerivedAttributeApplication(x.Interface.Info, ct).NormalizeWhitespace() ])); - // Report diagnostics for managed-to-unmanaged and unmanaged-to-managed stubs, deduplicating diagnostics that are reported for both. - context.RegisterDiagnostics( - interfaceAndMethodsContexts - .SelectMany(static (data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics).Union(data.DeclaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics)))); - var filesToGenerate = syntaxes .Select(static (methodSyntaxes, ct) => { @@ -443,7 +436,7 @@ private static IncrementalMethodStubGenerationContext CalculateSharedStubInforma ComInterfaceDispatchMarshallingInfo.Instance); } - private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct) + internal static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct) { ISignatureDiagnosticLocations locations = syntax is null ? NoneSignatureDiagnosticLocations.Instance diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs index e46ab33137dab4..a2438c9c583e92 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs @@ -45,20 +45,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Where( static modelData => modelData is not null); - var methodsWithDiagnostics = attributedMethods.Select(static (data, ct) => - { - Diagnostic? diagnostic = GetDiagnosticIfInvalidMethodForGeneration(data.Syntax, data.Symbol); - return new { data.Syntax, data.Symbol, Diagnostic = diagnostic }; - }); - - // Split the methods we want to generate and the ones we don't into two separate groups. - var methodsToGenerate = methodsWithDiagnostics.Where(static data => data.Diagnostic is null); - var invalidMethodDiagnostics = methodsWithDiagnostics.Where(static data => data.Diagnostic is not null); - - context.RegisterSourceOutput(invalidMethodDiagnostics, static (context, invalidMethod) => - { - context.ReportDiagnostic(invalidMethod.Diagnostic); - }); + // Filter out methods that are invalid for generation (diagnostics for invalid methods are reported by the analyzer). + var methodsToGenerate = attributedMethods.Where( + static data => data is not null && VtableIndexStubDiagnosticsAnalyzer.GetDiagnosticIfInvalidMethodForGeneration(data.Syntax, data.Symbol) is null); // Calculate all of information to generate both managed-to-unmanaged and unmanaged-to-managed stubs // for each method. @@ -84,8 +73,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithComparer(Comparers.GeneratedSyntax) .WithTrackingName(StepNames.GenerateManagedToNativeStub); - context.RegisterDiagnostics(generateManagedToNativeStub.SelectMany((stubInfo, ct) => stubInfo.Item2)); - context.RegisterConcatenatedSyntaxOutputs(generateManagedToNativeStub.Select((data, ct) => data.Item1), "ManagedToNativeStubs.g.cs"); // Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation. @@ -101,8 +88,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithComparer(Comparers.GeneratedSyntax) .WithTrackingName(StepNames.GenerateNativeToManagedStub); - context.RegisterDiagnostics(generateNativeToManagedStub.SelectMany((stubInfo, ct) => stubInfo.Item2)); - context.RegisterConcatenatedSyntaxOutputs(generateNativeToManagedStub.Select((data, ct) => data.Item1), "NativeToManagedStubs.g.cs"); // Generate the native interface metadata for each interface that contains a method with the [VirtualMethodIndex] attribute. @@ -195,7 +180,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }; } - private static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct) + internal static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct) { ct.ThrowIfCancellationRequested(); INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute); @@ -388,42 +373,6 @@ private static (MemberDeclarationSyntax, ImmutableArray) Generat methodStub.Diagnostics.Array.AddRange(diagnostics)); } - private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method) - { - // Verify the method has no generic types or defined implementation - // and is not marked static or sealed - if (methodSyntax.TypeParameterList is not null - || methodSyntax.Body is not null - || methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword) - || methodSyntax.Modifiers.Any(SyntaxKind.SealedKeyword)) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, methodSyntax.Identifier.GetLocation(), method.Name); - } - - // Verify that the types the method is declared in are marked partial. - for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) - { - if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, methodSyntax.Identifier.GetLocation(), method.Name, typeDecl.Identifier); - } - } - - // Verify the method does not have a ref return - if (method.ReturnsByRef || method.ReturnsByRefReadonly) - { - return Diagnostic.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, methodSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString()); - } - - // Verify there is an [UnmanagedObjectUnwrapperAttribute] - if (!method.ContainingType.GetAttributes().Any(att => att.AttributeClass.IsOfType(TypeNames.UnmanagedObjectUnwrapperAttribute))) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingUnmanagedObjectUnwrapperAttribute, methodSyntax.Identifier.GetLocation(), method.Name); - } - - return null; - } - private static MemberDeclarationSyntax GenerateNativeInterfaceMetadata(ContainingSyntaxContext context) { return context.WrapMemberInContainingSyntaxWithUnsafeModifier( diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs index 3bd2bae74d92d0..880c6a5f9bc7a0 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs @@ -10,7 +10,7 @@ using Xunit; using static Microsoft.Interop.UnitTests.TestUtils; using StringMarshalling = System.Runtime.InteropServices.StringMarshalling; -using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; +using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace ComInterfaceGenerator.Unit.Tests { diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs index a908d720673dfc..19aefa55d689d2 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs @@ -15,7 +15,7 @@ using Xunit; using static Microsoft.Interop.UnitTests.TestUtils; using StringMarshalling = System.Runtime.InteropServices.StringMarshalling; -using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; +using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace ComInterfaceGenerator.Unit.Tests { From d2f18a0b7107a69f571752dd0716f35443055611 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Feb 2026 23:57:44 +0000 Subject: [PATCH 3/4] Refactor ComInterfaceGeneratorDiagnosticsAnalyzer to use per-interface RegisterSymbolAction Co-authored-by: jkoritzinsky <1571408+jkoritzinsky@users.noreply.github.com> --- ...omInterfaceGeneratorDiagnosticsAnalyzer.cs | 295 +++++++++++------- .../AddMarshalAsToElementTests.cs | 2 +- .../TargetSignatureTests.cs | 2 +- 3 files changed, 186 insertions(+), 113 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs index 2396ad2c1613c0..b9d36ecacbf52a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs @@ -4,6 +4,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics; using System.Linq; using System.Threading; using Microsoft.CodeAnalysis; @@ -75,7 +76,9 @@ public override void Initialize(AnalysisContext context) compilationContext.Compilation, compilationContext.Compilation.GetEnvironmentFlags()); - var interfaceSymbols = new ConcurrentBag<(InterfaceDeclarationSyntax Syntax, INamedTypeSymbol Symbol)>(); + // Cache ComInterfaceInfo per symbol for deduplication when multiple interfaces share the same base. + // This avoids recomputing the same interface info when traversing the ancestor chain of different derived interfaces. + var interfaceInfoCache = new ConcurrentDictionary>(SymbolEqualityComparer.Default); compilationContext.RegisterSymbolAction(symbolContext => { @@ -83,159 +86,229 @@ public override void Initialize(AnalysisContext context) if (typeSymbol.TypeKind != TypeKind.Interface) return; + // Find the [GeneratedComInterface] attribute and the syntax node of the declaring partial interface + InterfaceDeclarationSyntax? ifaceSyntax = null; foreach (AttributeData attr in typeSymbol.GetAttributes()) { if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttrType)) { - // Use the syntax reference that contains the attribute application - // (important for partial types where the attribute is on a specific partial declaration) - var attrSyntaxRef = attr.ApplicationSyntaxReference; - if (attrSyntaxRef is not null) - { - var attrSyntax = attrSyntaxRef.GetSyntax(symbolContext.CancellationToken); - if (attrSyntax.Parent?.Parent is InterfaceDeclarationSyntax ifaceSyntax) - { - interfaceSymbols.Add((ifaceSyntax, typeSymbol)); - } - } - else - { - foreach (SyntaxReference syntaxRef in typeSymbol.DeclaringSyntaxReferences) - { - if (syntaxRef.GetSyntax(symbolContext.CancellationToken) is InterfaceDeclarationSyntax ifaceSyntax) - { - interfaceSymbols.Add((ifaceSyntax, typeSymbol)); - break; - } - } - } + ifaceSyntax = FindInterfaceSyntaxWithAttribute(typeSymbol, generatedComInterfaceAttrType, symbolContext.CancellationToken); break; } } - }, SymbolKind.NamedType); - compilationContext.RegisterCompilationEndAction(endContext => - { - if (interfaceSymbols.IsEmpty) + if (ifaceSyntax is null) return; - AnalyzeInterfaces(endContext, env, interfaceSymbols.ToImmutableArray()); - }); + AnalyzeInterface(symbolContext, typeSymbol, ifaceSyntax, env, generatedComInterfaceAttrType, interfaceInfoCache); + }, SymbolKind.NamedType); }); } - private static void AnalyzeInterfaces( - CompilationAnalysisContext context, + private static void AnalyzeInterface( + SymbolAnalysisContext context, + INamedTypeSymbol typeSymbol, + InterfaceDeclarationSyntax ifaceSyntax, StubEnvironment env, - ImmutableArray<(InterfaceDeclarationSyntax Syntax, INamedTypeSymbol Symbol)> attributedInterfaces) + INamedTypeSymbol generatedComInterfaceAttrType, + ConcurrentDictionary> interfaceInfoCache) { CancellationToken ct = context.CancellationToken; - // This mirrors the analysis phase of ComInterfaceGenerator.Initialize. - List<(ComInterfaceInfo, INamedTypeSymbol)> interfaceInfos = new(); - HashSet<(ComInterfaceInfo, INamedTypeSymbol)> externalIfaces = new(ComInterfaceInfo.EqualityComparerForExternalIfaces.Instance); - List diags = new(); + // Get or compute ComInterfaceInfo for this interface (cached to avoid recomputing for shared base interfaces) + DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)> ciiResult = interfaceInfoCache.GetOrAdd( + typeSymbol, _ => ComInterfaceInfo.From(typeSymbol, ifaceSyntax, env, ct)); - foreach (var (syntax, symbol) in attributedInterfaces) + // Report interface-level diagnostics + if (ciiResult.HasDiagnostic) { - var cii = ComInterfaceInfo.From(symbol, syntax, env, ct); - if (cii.HasDiagnostic) - { - foreach (var diag in cii.Diagnostics) - diags.Add(diag); - } - if (cii.HasValue) - interfaceInfos.Add(cii.Value); - - var externalBase = ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(symbol); - if (!externalBase.IsDefaultOrEmpty) - { - foreach (var b in externalBase) - externalIfaces.Add(b); - } + foreach (DiagnosticInfo diag in ciiResult.Diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); } - interfaceInfos.AddRange(externalIfaces); + if (!ciiResult.HasValue) + return; - var comInterfaceContexts = ComInterfaceContext.GetContexts(interfaceInfos.Select(i => i.Item1).ToImmutableArray(), ct); + (ComInterfaceInfo cii, INamedTypeSymbol _) = ciiResult.Value; - Dictionary methodSymbols = new(); - List> methods = new(); + // Build the context chain for this interface (ancestors first, then this interface) to detect + // BaseInterfaceIsNotGenerated. Note: vtable indices don't need to be correct here since we're + // only reporting diagnostics, not emitting code. + ImmutableArray contextChain = BuildContextChain( + typeSymbol, cii, env, generatedComInterfaceAttrType, interfaceInfoCache, ct); - foreach (var cii in interfaceInfos) + ImmutableArray> contextResults = ComInterfaceContext.GetContexts(contextChain, ct); + // BuildContextChain always appends cii as the last element, so contextResults is always non-empty. + Debug.Assert(contextResults.Length > 0); + // The last entry corresponds to this interface + DiagnosticOr thisContextResult = contextResults[contextResults.Length - 1]; + if (thisContextResult.HasDiagnostic) { - var cmi = ComMethodInfo.GetMethodsFromInterface(cii, ct); - var inner = new List(); - foreach (var m in cmi) - { - if (m.HasDiagnostic) - { - foreach (var diag in m.Diagnostics) - diags.Add(diag); - } - if (m.HasValue) - { - inner.Add(m.Value.ComMethod); - methodSymbols.Add(m.Value.ComMethod, m.Value.Symbol); - } - } - methods.Add(inner); + foreach (DiagnosticInfo diag in thisContextResult.Diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); + return; } - List<(ComInterfaceContext, SequenceEqualImmutableArray)> ifaceCtxs = new(); - for (int i = 0; i < interfaceInfos.Count; i++) + // Process each method declared on this interface + foreach (DiagnosticOr<(ComMethodInfo ComMethod, IMethodSymbol Symbol)> methodResult in + ComMethodInfo.GetMethodsFromInterface((cii, typeSymbol), ct)) { - var cic = comInterfaceContexts[i]; - if (cic.HasDiagnostic) + if (methodResult.HasDiagnostic) + { + foreach (DiagnosticInfo diag in methodResult.Diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); + } + + if (!methodResult.HasValue) + continue; + + (ComMethodInfo comMethod, IMethodSymbol methodSymbol) = methodResult.Value; + + if (comMethod.Syntax is null) + continue; // externally-defined method; no stub diagnostics to report + + // Note: the vtable index passed here (0) doesn't need to be the correct vtable slot since + // we're only reporting diagnostics, not emitting code. + IncrementalMethodStubGenerationContext stubContext = ComInterfaceGenerator.CalculateStubInformation( + comMethod.Syntax, + methodSymbol, + 0, + env, + cii, + ct); + + if (stubContext is not SourceAvailableIncrementalMethodStubGenerationContext srcCtx) + continue; + + ImmutableArray managedToNativeDiags = ImmutableArray.Empty; + ImmutableArray nativeToManagedDiags = ImmutableArray.Empty; + + if (srcCtx.VtableIndexData.Direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) { - foreach (var diag in cic.Diagnostics) - diags.Add(diag); + (_, managedToNativeDiags) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(srcCtx, ComInterfaceGeneratorHelpers.GetGeneratorResolver); } - if (cic.HasValue) + if (srcCtx.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) { - ifaceCtxs.Add((cic.Value, methods[i].ToSequenceEqualImmutableArray())); + (_, nativeToManagedDiags) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(srcCtx, ComInterfaceGeneratorHelpers.GetGeneratorResolver); } + + // Deduplicate diagnostics reported for both directions (matching original generator behavior) + foreach (DiagnosticInfo diag in managedToNativeDiags.Union(nativeToManagedDiags)) + context.ReportDiagnostic(diag.ToDiagnostic()); } + } - // Report interface-level and method-level diagnostics - foreach (var diag in diags) - context.ReportDiagnostic(diag.ToDiagnostic()); + /// + /// Builds the ancestor chain for context creation (root-to-parent order, then the current interface last). + /// Only successfully-computed ancestors are included; if an ancestor fails, the chain stops there and + /// will emit + /// for the next derived interface. + /// + private static ImmutableArray BuildContextChain( + INamedTypeSymbol typeSymbol, + ComInterfaceInfo cii, + StubEnvironment env, + INamedTypeSymbol generatedComInterfaceAttrType, + ConcurrentDictionary> interfaceInfoCache, + CancellationToken ct) + { + // For external base interfaces, CreateInterfaceInfoForBaseInterfacesInOtherCompilations already + // provides the full ancestor chain ordered from root to immediate parent. + ImmutableArray<(ComInterfaceInfo, INamedTypeSymbol)> externalBases = + ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(typeSymbol); + if (!externalBases.IsEmpty) + { + return [.. externalBases.Select(static e => e.Item1), cii]; + } - var result = ComMethodContext.CalculateAllMethods(ifaceCtxs, ct); + // Traverse same-compilation base interfaces, inserting at the front to get root-first order. + var ancestorChain = new List(); + INamedTypeSymbol current = typeSymbol; - List methodContexts = new(); - foreach (var data in result) + while (true) { - methodContexts.Add(new ComMethodContext( - data.Method, - data.OwningInterface, - ComInterfaceGenerator.CalculateStubInformation( - data.Method.MethodInfo.Syntax, - methodSymbols[data.Method.MethodInfo], - data.Method.Index, - env, - data.OwningInterface.Info, - ct))); + INamedTypeSymbol? baseSymbol = FindBaseComInterfaceSymbol(current, generatedComInterfaceAttrType); + if (baseSymbol is null) + break; + + if (!SymbolEqualityComparer.Default.Equals(baseSymbol.ContainingAssembly, typeSymbol.ContainingAssembly)) + { + // Switch to external base handling + ImmutableArray<(ComInterfaceInfo, INamedTypeSymbol)> externalInfos = + ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(current); + ancestorChain.InsertRange(0, externalInfos.Select(static e => e.Item1)); + break; + } + + // Get or compute the base's ComInterfaceInfo (using the cache for deduplication) + DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)> baseResult = interfaceInfoCache.GetOrAdd( + baseSymbol, + sym => + { + InterfaceDeclarationSyntax? baseSyntax = FindInterfaceSyntaxWithAttribute(sym, generatedComInterfaceAttrType, ct); + if (baseSyntax is null) + return DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)>.From( + DiagnosticInfo.Create(GeneratorDiagnostics.CannotAnalyzeInterfacePattern, sym.Locations.FirstOrDefault() ?? Location.None, sym.Name)); + return ComInterfaceInfo.From(sym, baseSyntax, env, ct); + }); + + if (!baseResult.HasValue) + break; // Base failed — GetContexts will report BaseInterfaceIsNotGenerated for this interface + + ancestorChain.Insert(0, baseResult.Value.Item1); + current = baseSymbol; } - // Group method contexts by owning interface to match the generator's GroupComContextsForInterfaceGeneration - // and only report diagnostics for declared (non-inherited) methods. - var groupedByOwningInterface = methodContexts - .GroupBy(m => m.OwningInterface); + ancestorChain.Add(cii); + return ancestorChain.ToImmutableArray(); + } - foreach (var group in groupedByOwningInterface) + /// + /// Finds the first direct base interface of that has the . + /// + private static INamedTypeSymbol? FindBaseComInterfaceSymbol(INamedTypeSymbol typeSymbol, INamedTypeSymbol generatedComInterfaceAttrType) + { + foreach (INamedTypeSymbol iface in typeSymbol.Interfaces) { - var declaredMethods = group.Where(static m => !m.IsInheritedMethod).ToList(); - - // Report diagnostics for managed-to-unmanaged and unmanaged-to-managed stubs, - // deduplicating diagnostics that are reported for both (matching the generator behavior). - var allStubDiags = declaredMethods - .SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics.Array) - .Union(declaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics.Array)); + foreach (AttributeData attr in iface.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttrType)) + return iface; + } + } + return null; + } - foreach (var diag in allStubDiags) - context.ReportDiagnostic(diag.ToDiagnostic()); + /// + /// Finds the for that carries the . + /// For partial types, this is the specific partial declaration that has the attribute. + /// + private static InterfaceDeclarationSyntax? FindInterfaceSyntaxWithAttribute( + INamedTypeSymbol symbol, + INamedTypeSymbol generatedComInterfaceAttrType, + CancellationToken ct) + { + foreach (AttributeData attr in symbol.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttrType)) + { + SyntaxReference? attrSyntaxRef = attr.ApplicationSyntaxReference; + if (attrSyntaxRef is not null) + { + SyntaxNode attrSyntax = attrSyntaxRef.GetSyntax(ct); + // Attribute syntax structure: AttributeSyntax -> AttributeListSyntax -> InterfaceDeclarationSyntax + if (attrSyntax.Parent?.Parent is InterfaceDeclarationSyntax ifaceSyntax) + return ifaceSyntax; + } + foreach (SyntaxReference syntaxRef in symbol.DeclaringSyntaxReferences) + { + if (syntaxRef.GetSyntax(ct) is InterfaceDeclarationSyntax ifaceSyntax) + return ifaceSyntax; + } + break; + } } + return null; } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddMarshalAsToElementTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddMarshalAsToElementTests.cs index d7858a7fc75feb..21c84aef5698ed 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddMarshalAsToElementTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddMarshalAsToElementTests.cs @@ -10,7 +10,7 @@ using Xunit; using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpCodeFixVerifier< - Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer, + Microsoft.Interop.Analyzers.ComInterfaceGeneratorDiagnosticsAnalyzer, Microsoft.Interop.Analyzers.AddMarshalAsToElementFixer>; namespace ComInterfaceGenerator.Unit.Tests diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs index 2e7bb5815a27fd..e9729c57a075ae 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs @@ -15,7 +15,7 @@ using Microsoft.Interop; using Xunit; -using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; +using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace ComInterfaceGenerator.Unit.Tests { From ea11306045b616043003e8679337132786fc1f8c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:27:19 +0000 Subject: [PATCH 4/4] Remove dead Diagnostics tuple field and diags collection from ComInterfaceGenerator; simplify VtableIndexStubGenerator stub helpers Co-authored-by: jkoritzinsky <1571408+jkoritzinsky@users.noreply.github.com> --- .../ComInterfaceGenerator.cs | 23 ------------ .../VtableIndexStubGenerator.cs | 36 +++++++++---------- 2 files changed, 16 insertions(+), 43 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 45667149844ad1..4d71dc9d788a99 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -56,7 +56,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { return ( - Diagnostics: ImmutableArray.Empty.ToSequenceEqual(), InterfaceContexts: ImmutableArray.Empty.ToSequenceEqual(), MethodContexts: ImmutableArray.Empty.ToSequenceEqual() ); @@ -64,15 +63,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) StubEnvironment stubEnvironment = input.Right; List<(ComInterfaceInfo, INamedTypeSymbol)> interfaceInfos = new(); HashSet<(ComInterfaceInfo, INamedTypeSymbol)> externalIfaces = new(ComInterfaceInfo.EqualityComparerForExternalIfaces.Instance); - List diags = new(); foreach (var (syntax, symbol) in input.Left) { var cii = ComInterfaceInfo.From(symbol, syntax, stubEnvironment, CancellationToken.None); - if (cii.HasDiagnostic) - { - foreach (var diag in cii.Diagnostics) - diags.Add(diag); - } if (cii.HasValue) interfaceInfos.Add(cii.Value); var externalBase = ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(symbol); @@ -98,13 +91,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var inner = new List(); foreach (var m in cmi) { - if (m.HasDiagnostic) - { - foreach (var diag in m.Diagnostics) - { - diags.Add(diag); - } - } if (m.HasValue) { inner.Add(m.Value.ComMethod); @@ -118,14 +104,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) for (int i = 0; i < interfaceInfos.Count; i++) { var cic = comInterfaceContexts[i]; - var cii = interfaceInfos[i]; - if (cic.HasDiagnostic) - { - foreach (var diag in cic.Diagnostics) - { - diags.Add(diag); - } - } if (cic.HasValue) { ifaceCtxs.Add((cic.Value, methods[i].ToSequenceEqualImmutableArray())); @@ -151,7 +129,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return ( - Diagnostics: diags.ToSequenceEqualImmutableArray(), InterfaceContexts: ifaceCtxs.Select(x => x.Item1).Where(x => !x.IsExternallyDefined).ToSequenceEqualImmutableArray(), MethodContexts: methodContexts.ToSequenceEqualImmutableArray() ); diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs index a2438c9c583e92..4db8815cb94a8e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs @@ -64,31 +64,31 @@ public void Initialize(IncrementalGeneratorInitializationContext context) ) .WithTrackingName(StepNames.CalculateStubInformation); - // Generate the code for the managed-to-unmangaed stubs and the diagnostics from code-generation. - IncrementalValuesProvider<(MemberDeclarationSyntax, ImmutableArray)> generateManagedToNativeStub = generateStubInformation + // Generate the code for the managed-to-unmanaged stubs. + IncrementalValuesProvider generateManagedToNativeStub = generateStubInformation .Where(data => data.VtableIndexData.Direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) .Select( static (data, ct) => GenerateManagedToNativeStub(data) ) - .WithComparer(Comparers.GeneratedSyntax) + .WithComparer(SyntaxEquivalentComparer.Instance) .WithTrackingName(StepNames.GenerateManagedToNativeStub); - context.RegisterConcatenatedSyntaxOutputs(generateManagedToNativeStub.Select((data, ct) => data.Item1), "ManagedToNativeStubs.g.cs"); + context.RegisterConcatenatedSyntaxOutputs(generateManagedToNativeStub, "ManagedToNativeStubs.g.cs"); // Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation. IncrementalValuesProvider nativeToManagedStubContexts = generateStubInformation .Where(data => data.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional); - // Generate the code for the unmanaged-to-managed stubs and the diagnostics from code-generation. - IncrementalValuesProvider<(MemberDeclarationSyntax, ImmutableArray)> generateNativeToManagedStub = nativeToManagedStubContexts + // Generate the code for the unmanaged-to-managed stubs. + IncrementalValuesProvider generateNativeToManagedStub = nativeToManagedStubContexts .Select( static (data, ct) => GenerateNativeToManagedStub(data) ) - .WithComparer(Comparers.GeneratedSyntax) + .WithComparer(SyntaxEquivalentComparer.Instance) .WithTrackingName(StepNames.GenerateNativeToManagedStub); - context.RegisterConcatenatedSyntaxOutputs(generateNativeToManagedStub.Select((data, ct) => data.Item1), "NativeToManagedStubs.g.cs"); + context.RegisterConcatenatedSyntaxOutputs(generateNativeToManagedStub, "NativeToManagedStubs.g.cs"); // Generate the native interface metadata for each interface that contains a method with the [VirtualMethodIndex] attribute. IncrementalValuesProvider generateNativeInterface = generateStubInformation @@ -347,30 +347,26 @@ private static MarshallingInfo CreateExceptionMarshallingInfo(AttributeData virt return NoMarshallingInfo.Instance; } - private static (MemberDeclarationSyntax, ImmutableArray) GenerateManagedToNativeStub( + private static MemberDeclarationSyntax GenerateManagedToNativeStub( SourceAvailableIncrementalMethodStubGenerationContext methodStub) { - var (stub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); + var (stub, _) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); - return ( - methodStub.ContainingSyntaxContext.AddContainingSyntax( + return methodStub.ContainingSyntaxContext.AddContainingSyntax( NativeTypeContainingSyntax) .WrapMemberInContainingSyntaxWithUnsafeModifier( - stub), - methodStub.Diagnostics.Array.AddRange(diagnostics)); + stub); } - private static (MemberDeclarationSyntax, ImmutableArray) GenerateNativeToManagedStub( + private static MemberDeclarationSyntax GenerateNativeToManagedStub( SourceAvailableIncrementalMethodStubGenerationContext methodStub) { - var (stub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); + var (stub, _) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); - return ( - methodStub.ContainingSyntaxContext.AddContainingSyntax( + return methodStub.ContainingSyntaxContext.AddContainingSyntax( NativeTypeContainingSyntax) .WrapMemberInContainingSyntaxWithUnsafeModifier( - stub), - methodStub.Diagnostics.Array.AddRange(diagnostics)); + stub); } private static MemberDeclarationSyntax GenerateNativeInterfaceMetadata(ContainingSyntaxContext context)