From 3cb63841bd251c5e3f103481b5a0545288d92f6f Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Sun, 8 Mar 2026 18:23:25 +0300 Subject: [PATCH 1/7] Move disgnostic reporting out from COM class generator --- .../ComClassGeneratorDiagnosticsAnalyzer.cs | 87 +++++++++++++++++++ .../ComClassGenerator.cs | 17 ++-- .../gen/ComInterfaceGenerator/ComClassInfo.cs | 31 ++----- .../DiagnosticOr.cs | 26 ------ ...eneratorInitializationContextExtensions.cs | 20 ----- .../ComClassGeneratorDiagnostics.cs | 2 +- .../ComClassGeneratorOutputShape.cs | 10 +-- 7 files changed, 106 insertions(+), 87 deletions(-) create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs new file mode 100644 index 00000000000000..1c5c1454876cb0 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs @@ -0,0 +1,87 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; + +namespace Microsoft.Interop.Analyzers; + +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public sealed class ComClassGeneratorDiagnosticsAnalyzer : DiagnosticAnalyzer +{ + public override ImmutableArray SupportedDiagnostics { get; } = + ImmutableArray.Create( + GeneratorDiagnostics.RequiresAllowUnsafeBlocks, + GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier, + GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface); + + public override void Initialize(AnalysisContext context) + { + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + context.EnableConcurrentExecution(); + + context.RegisterCompilationStartAction(static context => + { + bool unsafeCodeIsEnabled = context.Compilation.Options is CSharpCompilationOptions { AllowUnsafe: true }; + INamedTypeSymbol? generatedComClassAttributeType = context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComClassAttribute); + INamedTypeSymbol? generatedComInterfaceAttributeType = context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute); + + context.RegisterSymbolAction(context => AnalyzeNamedType(context, unsafeCodeIsEnabled, generatedComClassAttributeType, generatedComInterfaceAttributeType), SymbolKind.NamedType); + }); + } + + private static void AnalyzeNamedType(SymbolAnalysisContext context, bool unsafeCodeIsEnabled, INamedTypeSymbol? generatedComClassAttributeType, INamedTypeSymbol? generatedComInterfaceAttributeType) + { + if (context.Symbol is not INamedTypeSymbol { TypeKind: TypeKind.Class } classToAnalyze) + { + return; + } + + if (!classToAnalyze.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, generatedComClassAttributeType))) + { + return; + } + + Location location = classToAnalyze.Locations.First(); + + if (!unsafeCodeIsEnabled) + { + context.ReportDiagnostic( + Diagnostic.Create( + GeneratorDiagnostics.RequiresAllowUnsafeBlocks, + location)); + } + + var declarationNode = (TypeDeclarationSyntax)location.SourceTree.GetRoot().FindNode(location.SourceSpan); + + if (!declarationNode.IsInPartialContext(out _)) + { + context.ReportDiagnostic( + Diagnostic.Create( + GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier, + location, + classToAnalyze)); + } + + foreach (INamedTypeSymbol iface in classToAnalyze.AllInterfaces) + { + if (iface.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, generatedComInterfaceAttributeType)) is { } generatedComInterfaceAttribute && + GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute).Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper)) + { + return; + } + } + + // Class doesn't implement any generated COM interface. Report a warning about that + context.ReportDiagnostic( + Diagnostic.Create( + GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface, + location, + classToAnalyze)); + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs index 1455804b0ba9db..6e5b65dbef2434 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs @@ -20,22 +20,19 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var unsafeCodeIsEnabled = context.CompilationProvider.Select((comp, ct) => comp.Options is CSharpCompilationOptions { AllowUnsafe: true }); // Unsafe code enabled // Get all types with the [GeneratedComClassAttribute] attribute. - var attributedClassesOrDiagnostics = context.SyntaxProvider + var attributedClasses = context.SyntaxProvider .ForAttributeWithMetadataName( TypeNames.GeneratedComClassAttribute, static (node, ct) => node is ClassDeclarationSyntax, - static (context, ct) => context) - .Combine(unsafeCodeIsEnabled) - .Select(static (data, ct) => + static (context, _) => { - var context = data.Left; - var unsafeCodeIsEnabled = data.Right; var type = (INamedTypeSymbol)context.TargetSymbol; var syntax = (ClassDeclarationSyntax)context.TargetNode; - return ComClassInfo.From(type, syntax, unsafeCodeIsEnabled); - }); - - var attributedClasses = context.FilterAndReportDiagnostics(attributedClassesOrDiagnostics); + return ComClassInfo.TryGetFrom(type, syntax); + }) + .Combine(unsafeCodeIsEnabled) + .Where(static data => data.Left is not null && data.Right) + .Select(static (data, _) => data.Left!); var classInfoType = attributedClasses .Select(static (info, ct) => new ItemAndSyntaxes(info, diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs index 920409ced00f87..02317222eb8f70 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs @@ -23,20 +23,11 @@ private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxC ImplementedInterfacesNames = implementedInterfacesNames; } - public static DiagnosticOr From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, bool unsafeCodeIsEnabled) + public static ComClassInfo? TryGetFrom(INamedTypeSymbol type, ClassDeclarationSyntax syntax) { - if (!unsafeCodeIsEnabled) - { - return DiagnosticOr.From(DiagnosticInfo.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, syntax.Identifier.GetLocation())); - } - if (!syntax.IsInPartialContext(out _)) { - return DiagnosticOr.From( - DiagnosticInfo.Create( - GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier, - syntax.Identifier.GetLocation(), - type.ToDisplayString())); + return null; } ImmutableArray.Builder names = ImmutableArray.CreateBuilder(); @@ -53,19 +44,11 @@ public static DiagnosticOr From(INamedTypeSymbol type, ClassDeclar } } - if (names.Count == 0) - { - return DiagnosticOr.From(DiagnosticInfo.Create(GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface, - syntax.Identifier.GetLocation(), - type.ToDisplayString())); - } - - return DiagnosticOr.From( - new ComClassInfo( - type.ToDisplayString(), - new ContainingSyntaxContext(syntax), - new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), - new(names.ToImmutable()))); + return names.Count == 0 ? null : new ComClassInfo( + type.ToDisplayString(), + new ContainingSyntaxContext(syntax), + new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), + new(names.ToImmutable())); } public bool Equals(ComClassInfo? other) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/DiagnosticOr.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/DiagnosticOr.cs index 2fe420651f7338..0c36ae2c82ce83 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/DiagnosticOr.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/DiagnosticOr.cs @@ -2,11 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; -using System.Linq; -using Microsoft.CodeAnalysis; namespace Microsoft.Interop { @@ -88,27 +85,4 @@ public static DiagnosticOr From(T value, params DiagnosticInfo[] diagnostics) return new ValueAndDiagnostic(value, ImmutableArray.Create(diagnostics)); } } - - public static class DiagnosticOrTHelperExtensions - { - /// - /// Splits the elements of into a values provider and a diagnostics provider. - /// - public static (IncrementalValuesProvider, IncrementalValuesProvider) Split(this IncrementalValuesProvider> provider) - { - var values = provider.Where(x => x.HasValue).Select(static (x, ct) => x.Value); - var diagnostics = provider.Where(x => x.HasDiagnostic).SelectMany(static (x, ct) => x.Diagnostics); - return (values, diagnostics); - } - - /// - /// Filters the by whether or not the is a , reports the diagnostics, and returns the values. - /// - public static IncrementalValuesProvider FilterAndReportDiagnostics(this IncrementalGeneratorInitializationContext ctx, IncrementalValuesProvider> diagnosticOrValues) - { - var (values, diagnostics) = diagnosticOrValues.Split(); - ctx.RegisterDiagnostics(diagnostics); - return values; - } - } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs index 88094b0cff8518..5db454c730e068 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs @@ -1,14 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; -using System.Reflection; using System.Text; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Diagnostics; namespace Microsoft.Interop { @@ -50,22 +46,6 @@ public static IncrementalValueProvider CreateStubEnvironmentPro new StubEnvironment(data.Left, data.Right)); } - public static void RegisterDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider diagnostics) - { - context.RegisterSourceOutput(diagnostics.Where(diag => diag is not null), (context, diagnostic) => - { - context.ReportDiagnostic(diagnostic.ToDiagnostic()); - }); - } - - public static void RegisterDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider diagnostics) - { - context.RegisterSourceOutput(diagnostics.Where(diag => diag is not null), (context, diagnostic) => - { - context.ReportDiagnostic(diagnostic); - }); - } - public static void RegisterConcatenatedSyntaxOutputs(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider nodes, string fileName) where TNode : SyntaxNode { diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs index f07e6d16b6322d..83b3603ffa3936 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs @@ -7,7 +7,7 @@ using Microsoft.CodeAnalysis.Testing; using Microsoft.Interop; using Xunit; -using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; +using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace ComInterfaceGenerator.Unit.Tests { diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs index c61e00b0a3d7e0..1b4dcc94807397 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs @@ -104,7 +104,7 @@ class GeneratedShapeTest : VerifyCS.Test private readonly string[] _typeNames; public GeneratedShapeTest(params string[] typeNames) - :base(referenceAncillaryInterop: false) + : base(referenceAncillaryInterop: false) { _typeNames = typeNames; } @@ -129,11 +129,9 @@ private static void VerifyShape(Compilation comp, string userDefinedClassMetadat userDefinedClass.GetAttributes(), attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, comExposedClassAttribute)); - Assert.Collection(Assert.IsAssignableFrom(iUnknownDerivedAttribute.AttributeClass).TypeArguments, - infoType => - { - Assert.True(Assert.IsAssignableFrom(infoType).IsFileLocal); - }); + Assert.NotNull(iUnknownDerivedAttribute.AttributeClass); + ITypeSymbol typeArgument = Assert.Single(iUnknownDerivedAttribute.AttributeClass.TypeArguments); + Assert.True(Assert.IsType(typeArgument, exactMatch: false).IsFileLocal); } } } From 083004a98d993c0d20da6de5c1a902769baef6d4 Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Sun, 8 Mar 2026 18:28:26 +0300 Subject: [PATCH 2/7] Compare symbols instead of relying on `ToString` representation --- .../gen/ComInterfaceGenerator/ComClassGenerator.cs | 3 ++- .../gen/ComInterfaceGenerator/ComClassInfo.cs | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs index 6e5b65dbef2434..ab81a0e152ac44 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs @@ -28,7 +28,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var type = (INamedTypeSymbol)context.TargetSymbol; var syntax = (ClassDeclarationSyntax)context.TargetNode; - return ComClassInfo.TryGetFrom(type, syntax); + var compilation = context.SemanticModel.Compilation; + return ComClassInfo.TryGetFrom(type, syntax, compilation); }) .Combine(unsafeCodeIsEnabled) .Where(static data => data.Left is not null && data.Right) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs index 02317222eb8f70..8550e2d5bd063a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs @@ -5,6 +5,7 @@ using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; namespace Microsoft.Interop { @@ -23,7 +24,7 @@ private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxC ImplementedInterfacesNames = implementedInterfacesNames; } - public static ComClassInfo? TryGetFrom(INamedTypeSymbol type, ClassDeclarationSyntax syntax) + public static ComClassInfo? TryGetFrom(INamedTypeSymbol type, ClassDeclarationSyntax syntax, Compilation compilation) { if (!syntax.IsInPartialContext(out _)) { @@ -31,9 +32,10 @@ private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxC } ImmutableArray.Builder names = ImmutableArray.CreateBuilder(); + INamedTypeSymbol? generatedComInterfaceAttributeType = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute); foreach (INamedTypeSymbol iface in type.AllInterfaces) { - AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute); + AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttributeType)); if (generatedComInterfaceAttribute is not null) { var attributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute); From ed0c42d334b0d46ff3e78b9ff0d0ba016ba1d0dd Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Sun, 8 Mar 2026 18:49:53 +0300 Subject: [PATCH 3/7] Show that analyzer reports multiple diagnostics at once --- .../ComClassGeneratorDiagnostics.cs | 53 ++++++++++++++++--- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs index 83b3603ffa3936..7820fcd3e899dc 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs @@ -72,7 +72,7 @@ protected override CompilationOptions CreateCompilationOptions() } [Fact] - public async Task UnsafeCodeNotEnabledWarns() + public async Task UnsafeCodeNotEnabledErrors() { string source = """ using System.Runtime.InteropServices; @@ -88,15 +88,54 @@ internal partial interface INativeAPI [GeneratedComClass] internal partial class {|#0:C|} : INativeAPI {} } + """; + + var test = new UnsafeBlocksNotAllowedTest(false) + { + TestCode = source, + ExpectedDiagnostics = + { + new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) + .WithLocation(0) + .WithArguments("Test.C") + } + }; + + await test.RunAsync(); + } + [Fact] + public async Task UnsafeCodeNotEnabledAndNoPartialModifierProducesBothErrors() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + public partial class Test + { + [GeneratedComInterface] + internal partial interface INativeAPI + { + } + + [GeneratedComClass] + internal class {|#0:C|} : INativeAPI {} + } """; - var test = new UnsafeBlocksNotAllowedTest(false); - test.TestState.Sources.Add(source); - test.ExpectedDiagnostics.Add( - new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) - .WithLocation(0) - .WithArguments("Test.C")); + var test = new UnsafeBlocksNotAllowedTest(false) + { + TestCode = source, + ExpectedDiagnostics = + { + new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) + .WithLocation(0) + .WithArguments("Test.C"), + new DiagnosticResult(GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier) + .WithLocation(0) + .WithArguments("Test.C") + } + }; await test.RunAsync(); } From 50897e797cf7ca7c039fa6b95a8899fa67f0977d Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Sun, 8 Mar 2026 18:58:04 +0300 Subject: [PATCH 4/7] Avoid reporting a warning if an error is already present --- .../ComClassGeneratorDiagnosticsAnalyzer.cs | 17 +++++- .../ComClassGeneratorDiagnostics.cs | 58 ++++++++++++++++++- 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs index 1c5c1454876cb0..77163de7edfebf 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs @@ -29,7 +29,13 @@ public override void Initialize(AnalysisContext context) { bool unsafeCodeIsEnabled = context.Compilation.Options is CSharpCompilationOptions { AllowUnsafe: true }; INamedTypeSymbol? generatedComClassAttributeType = context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComClassAttribute); - INamedTypeSymbol? generatedComInterfaceAttributeType = context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute); + + // We use this type only to report warning diagnostic. We also don't report a warning if there is at least one error. + // Given that with unsafe code disabled we will get an error on each declaration, we can skip + // unnecessary work of getting this symbol here + INamedTypeSymbol? generatedComInterfaceAttributeType = unsafeCodeIsEnabled + ? context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute) + : null; context.RegisterSymbolAction(context => AnalyzeNamedType(context, unsafeCodeIsEnabled, generatedComClassAttributeType, generatedComInterfaceAttributeType), SymbolKind.NamedType); }); @@ -48,6 +54,7 @@ private static void AnalyzeNamedType(SymbolAnalysisContext context, bool unsafeC } Location location = classToAnalyze.Locations.First(); + bool hasErrors = false; if (!unsafeCodeIsEnabled) { @@ -55,6 +62,7 @@ private static void AnalyzeNamedType(SymbolAnalysisContext context, bool unsafeC Diagnostic.Create( GeneratorDiagnostics.RequiresAllowUnsafeBlocks, location)); + hasErrors = true; } var declarationNode = (TypeDeclarationSyntax)location.SourceTree.GetRoot().FindNode(location.SourceSpan); @@ -66,6 +74,13 @@ private static void AnalyzeNamedType(SymbolAnalysisContext context, bool unsafeC GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier, location, classToAnalyze)); + hasErrors = true; + } + + if (hasErrors) + { + // If we already reported at least one error avoid stacking a warning on top of it + return; } foreach (INamedTypeSymbol iface in classToAnalyze.AllInterfaces) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs index 7820fcd3e899dc..f058a31cad187f 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs @@ -27,7 +27,6 @@ partial interface INativeAPI [GeneratedComClass] internal class {|#0:C|} : INativeAPI {} - """; await VerifyCS.VerifySourceGeneratorAsync( @@ -54,7 +53,6 @@ partial interface INativeAPI [GeneratedComClass] internal partial class {|#0:C|} : INativeAPI {} } - """; await VerifyCS.VerifySourceGeneratorAsync( @@ -155,7 +153,6 @@ internal interface INativeAPI [GeneratedComClass] internal partial class {|#0:C|} : INativeAPI {} } - """; await VerifyCS.VerifySourceGeneratorAsync( @@ -164,5 +161,60 @@ await VerifyCS.VerifySourceGeneratorAsync( .WithLocation(0) .WithArguments("Test.C")); } + + [Fact] + public async Task NoWarningIfErrorIsAlreadyPresent_UnsafeCodeNotEnabledError() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + public partial class Test{ + internal interface INativeAPI + { + } + + [GeneratedComClass] + internal partial class {|#0:C|} : INativeAPI {} + } + """; + + var test = new UnsafeBlocksNotAllowedTest(false) + { + TestCode = source, + ExpectedDiagnostics = + { + new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) + .WithLocation(0) + .WithArguments("Test.C") + } + }; + + await test.RunAsync(); + } + + [Fact] + public async Task NoWarningIfErrorIsAlreadyPresent_NotPartialContextError() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + public partial class Test{ + internal interface INativeAPI + { + } + + [GeneratedComClass] + internal class {|#0:C|} : INativeAPI {} + } + """; + + await VerifyCS.VerifySourceGeneratorAsync( + source, + new DiagnosticResult(GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier) + .WithLocation(0) + .WithArguments("Test.C")); + } } } From bb21bcb4412736cd7041a093ffec001ec64a2da9 Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Mon, 9 Mar 2026 07:53:07 +0300 Subject: [PATCH 5/7] Simplify --- .../gen/ComInterfaceGenerator/ComClassGenerator.cs | 5 +---- .../gen/ComInterfaceGenerator/ComClassInfo.cs | 6 ++++++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs index ab81a0e152ac44..37c246cf74800d 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs @@ -18,7 +18,6 @@ public class ComClassGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { - var unsafeCodeIsEnabled = context.CompilationProvider.Select((comp, ct) => comp.Options is CSharpCompilationOptions { AllowUnsafe: true }); // Unsafe code enabled // Get all types with the [GeneratedComClassAttribute] attribute. var attributedClasses = context.SyntaxProvider .ForAttributeWithMetadataName( @@ -31,9 +30,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var compilation = context.SemanticModel.Compilation; return ComClassInfo.TryGetFrom(type, syntax, compilation); }) - .Combine(unsafeCodeIsEnabled) - .Where(static data => data.Left is not null && data.Right) - .Select(static (data, _) => data.Left!); + .Where(static info => info is not null); var classInfoType = attributedClasses .Select(static (info, ct) => new ItemAndSyntaxes(info, diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs index 8550e2d5bd063a..2b5f5d9e3e6820 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs @@ -4,6 +4,7 @@ using System.Collections.Immutable; using System.Linq; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; @@ -26,6 +27,11 @@ private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxC public static ComClassInfo? TryGetFrom(INamedTypeSymbol type, ClassDeclarationSyntax syntax, Compilation compilation) { + if (compilation.Options is not CSharpCompilationOptions { AllowUnsafe: true }) + { + return null; + } + if (!syntax.IsInPartialContext(out _)) { return null; From 34a5007422995666e67f9b7e6faaf9b9a3ac0da5 Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Tue, 10 Mar 2026 19:38:34 +0300 Subject: [PATCH 6/7] Remove redundant verification arguments from diagnostics without any arguments to begin with --- .../ComClassGeneratorDiagnostics.cs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs index f058a31cad187f..c6443b5160553a 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs @@ -95,7 +95,6 @@ internal partial class {|#0:C|} : INativeAPI {} { new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) .WithLocation(0) - .WithArguments("Test.C") } }; @@ -127,8 +126,7 @@ internal class {|#0:C|} : INativeAPI {} ExpectedDiagnostics = { new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) - .WithLocation(0) - .WithArguments("Test.C"), + .WithLocation(0), new DiagnosticResult(GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier) .WithLocation(0) .WithArguments("Test.C") @@ -186,7 +184,6 @@ internal partial class {|#0:C|} : INativeAPI {} { new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) .WithLocation(0) - .WithArguments("Test.C") } }; From 7f3bce9103f8c343e3ecf80b6c74edbdf2796ede Mon Sep 17 00:00:00 2001 From: DoctorKrolic Date: Tue, 10 Mar 2026 19:51:56 +0300 Subject: [PATCH 7/7] Reuse analysis code --- .../ComClassGeneratorDiagnosticsAnalyzer.cs | 40 ++++++++++--------- .../ComClassGenerator.cs | 13 +++++- .../gen/ComInterfaceGenerator/ComClassInfo.cs | 17 +------- 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs index 77163de7edfebf..35952e323a5db2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using Microsoft.CodeAnalysis; @@ -53,15 +54,20 @@ private static void AnalyzeNamedType(SymbolAnalysisContext context, bool unsafeC return; } - Location location = classToAnalyze.Locations.First(); + foreach (Diagnostic diagnostic in GetDiagnosticsForAnnotatedClass(classToAnalyze, unsafeCodeIsEnabled, generatedComInterfaceAttributeType)) + { + context.ReportDiagnostic(diagnostic); + } + } + + public static IEnumerable GetDiagnosticsForAnnotatedClass(INamedTypeSymbol annotatedClass, bool unsafeCodeIsEnabled, INamedTypeSymbol? generatedComInterfaceAttributeType) + { + Location location = annotatedClass.Locations.First(); bool hasErrors = false; if (!unsafeCodeIsEnabled) { - context.ReportDiagnostic( - Diagnostic.Create( - GeneratorDiagnostics.RequiresAllowUnsafeBlocks, - location)); + yield return Diagnostic.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, location); hasErrors = true; } @@ -69,34 +75,32 @@ private static void AnalyzeNamedType(SymbolAnalysisContext context, bool unsafeC if (!declarationNode.IsInPartialContext(out _)) { - context.ReportDiagnostic( - Diagnostic.Create( - GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier, - location, - classToAnalyze)); + yield return Diagnostic.Create( + GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier, + location, + annotatedClass); hasErrors = true; } if (hasErrors) { // If we already reported at least one error avoid stacking a warning on top of it - return; + yield break; } - foreach (INamedTypeSymbol iface in classToAnalyze.AllInterfaces) + foreach (INamedTypeSymbol iface in annotatedClass.AllInterfaces) { if (iface.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, generatedComInterfaceAttributeType)) is { } generatedComInterfaceAttribute && GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute).Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper)) { - return; + yield break; } } // Class doesn't implement any generated COM interface. Report a warning about that - context.ReportDiagnostic( - Diagnostic.Create( - GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface, - location, - classToAnalyze)); + yield return Diagnostic.Create( + GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface, + location, + annotatedClass); } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs index 37c246cf74800d..4f8940ca0df19f 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs @@ -8,6 +8,8 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; +using Microsoft.Interop.Analyzers; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; using static Microsoft.Interop.SyntaxFactoryExtensions; @@ -28,7 +30,16 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var type = (INamedTypeSymbol)context.TargetSymbol; var syntax = (ClassDeclarationSyntax)context.TargetNode; var compilation = context.SemanticModel.Compilation; - return ComClassInfo.TryGetFrom(type, syntax, compilation); + var unsafeCodeIsEnabled = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true }; + INamedTypeSymbol? generatedComInterfaceAttributeType = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute); + + // Currently all reported diagnostics are fatal to the generator + if (ComClassGeneratorDiagnosticsAnalyzer.GetDiagnosticsForAnnotatedClass(type, unsafeCodeIsEnabled, generatedComInterfaceAttributeType).Any()) + { + return null; + } + + return ComClassInfo.From(type, syntax, generatedComInterfaceAttributeType); }) .Where(static info => info is not null); diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs index 2b5f5d9e3e6820..838c23c01644bc 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs @@ -4,9 +4,7 @@ using System.Collections.Immutable; using System.Linq; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; namespace Microsoft.Interop { @@ -25,20 +23,9 @@ private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxC ImplementedInterfacesNames = implementedInterfacesNames; } - public static ComClassInfo? TryGetFrom(INamedTypeSymbol type, ClassDeclarationSyntax syntax, Compilation compilation) + public static ComClassInfo From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, INamedTypeSymbol? generatedComInterfaceAttributeType) { - if (compilation.Options is not CSharpCompilationOptions { AllowUnsafe: true }) - { - return null; - } - - if (!syntax.IsInPartialContext(out _)) - { - return null; - } - ImmutableArray.Builder names = ImmutableArray.CreateBuilder(); - INamedTypeSymbol? generatedComInterfaceAttributeType = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute); foreach (INamedTypeSymbol iface in type.AllInterfaces) { AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttributeType)); @@ -52,7 +39,7 @@ private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxC } } - return names.Count == 0 ? null : new ComClassInfo( + return new ComClassInfo( type.ToDisplayString(), new ContainingSyntaxContext(syntax), new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),