diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs new file mode 100644 index 00000000000000..35952e323a5db2 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComClassGeneratorDiagnosticsAnalyzer.cs @@ -0,0 +1,106 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; + +namespace Microsoft.Interop.Analyzers; + +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public sealed class ComClassGeneratorDiagnosticsAnalyzer : DiagnosticAnalyzer +{ + public override ImmutableArray SupportedDiagnostics { get; } = + ImmutableArray.Create( + GeneratorDiagnostics.RequiresAllowUnsafeBlocks, + GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier, + GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface); + + public override void Initialize(AnalysisContext context) + { + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + context.EnableConcurrentExecution(); + + context.RegisterCompilationStartAction(static context => + { + bool unsafeCodeIsEnabled = context.Compilation.Options is CSharpCompilationOptions { AllowUnsafe: true }; + INamedTypeSymbol? generatedComClassAttributeType = context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComClassAttribute); + + // We use this type only to report warning diagnostic. We also don't report a warning if there is at least one error. + // Given that with unsafe code disabled we will get an error on each declaration, we can skip + // unnecessary work of getting this symbol here + INamedTypeSymbol? generatedComInterfaceAttributeType = unsafeCodeIsEnabled + ? context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute) + : null; + + context.RegisterSymbolAction(context => AnalyzeNamedType(context, unsafeCodeIsEnabled, generatedComClassAttributeType, generatedComInterfaceAttributeType), SymbolKind.NamedType); + }); + } + + private static void AnalyzeNamedType(SymbolAnalysisContext context, bool unsafeCodeIsEnabled, INamedTypeSymbol? generatedComClassAttributeType, INamedTypeSymbol? generatedComInterfaceAttributeType) + { + if (context.Symbol is not INamedTypeSymbol { TypeKind: TypeKind.Class } classToAnalyze) + { + return; + } + + if (!classToAnalyze.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, generatedComClassAttributeType))) + { + return; + } + + foreach (Diagnostic diagnostic in GetDiagnosticsForAnnotatedClass(classToAnalyze, unsafeCodeIsEnabled, generatedComInterfaceAttributeType)) + { + context.ReportDiagnostic(diagnostic); + } + } + + public static IEnumerable GetDiagnosticsForAnnotatedClass(INamedTypeSymbol annotatedClass, bool unsafeCodeIsEnabled, INamedTypeSymbol? generatedComInterfaceAttributeType) + { + Location location = annotatedClass.Locations.First(); + bool hasErrors = false; + + if (!unsafeCodeIsEnabled) + { + yield return Diagnostic.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, location); + hasErrors = true; + } + + var declarationNode = (TypeDeclarationSyntax)location.SourceTree.GetRoot().FindNode(location.SourceSpan); + + if (!declarationNode.IsInPartialContext(out _)) + { + yield return Diagnostic.Create( + GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier, + location, + annotatedClass); + hasErrors = true; + } + + if (hasErrors) + { + // If we already reported at least one error avoid stacking a warning on top of it + yield break; + } + + foreach (INamedTypeSymbol iface in annotatedClass.AllInterfaces) + { + if (iface.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, generatedComInterfaceAttributeType)) is { } generatedComInterfaceAttribute && + GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute).Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper)) + { + yield break; + } + } + + // Class doesn't implement any generated COM interface. Report a warning about that + yield return Diagnostic.Create( + GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface, + location, + annotatedClass); + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs index 1455804b0ba9db..4f8940ca0df19f 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs @@ -8,6 +8,8 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; +using Microsoft.Interop.Analyzers; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; using static Microsoft.Interop.SyntaxFactoryExtensions; @@ -18,24 +20,28 @@ public class ComClassGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { - var unsafeCodeIsEnabled = context.CompilationProvider.Select((comp, ct) => comp.Options is CSharpCompilationOptions { AllowUnsafe: true }); // Unsafe code enabled // Get all types with the [GeneratedComClassAttribute] attribute. - var attributedClassesOrDiagnostics = context.SyntaxProvider + var attributedClasses = context.SyntaxProvider .ForAttributeWithMetadataName( TypeNames.GeneratedComClassAttribute, static (node, ct) => node is ClassDeclarationSyntax, - static (context, ct) => context) - .Combine(unsafeCodeIsEnabled) - .Select(static (data, ct) => + static (context, _) => { - var context = data.Left; - var unsafeCodeIsEnabled = data.Right; var type = (INamedTypeSymbol)context.TargetSymbol; var syntax = (ClassDeclarationSyntax)context.TargetNode; - return ComClassInfo.From(type, syntax, unsafeCodeIsEnabled); - }); + var compilation = context.SemanticModel.Compilation; + var unsafeCodeIsEnabled = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true }; + INamedTypeSymbol? generatedComInterfaceAttributeType = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute); - var attributedClasses = context.FilterAndReportDiagnostics(attributedClassesOrDiagnostics); + // Currently all reported diagnostics are fatal to the generator + if (ComClassGeneratorDiagnosticsAnalyzer.GetDiagnosticsForAnnotatedClass(type, unsafeCodeIsEnabled, generatedComInterfaceAttributeType).Any()) + { + return null; + } + + return ComClassInfo.From(type, syntax, generatedComInterfaceAttributeType); + }) + .Where(static info => info is not null); var classInfoType = attributedClasses .Select(static (info, ct) => new ItemAndSyntaxes(info, diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs index 920409ced00f87..838c23c01644bc 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs @@ -23,26 +23,12 @@ private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxC ImplementedInterfacesNames = implementedInterfacesNames; } - public static DiagnosticOr From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, bool unsafeCodeIsEnabled) + public static ComClassInfo From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, INamedTypeSymbol? generatedComInterfaceAttributeType) { - if (!unsafeCodeIsEnabled) - { - return DiagnosticOr.From(DiagnosticInfo.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, syntax.Identifier.GetLocation())); - } - - if (!syntax.IsInPartialContext(out _)) - { - return DiagnosticOr.From( - DiagnosticInfo.Create( - GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier, - syntax.Identifier.GetLocation(), - type.ToDisplayString())); - } - ImmutableArray.Builder names = ImmutableArray.CreateBuilder(); foreach (INamedTypeSymbol iface in type.AllInterfaces) { - AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute); + AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttributeType)); if (generatedComInterfaceAttribute is not null) { var attributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute); @@ -53,19 +39,11 @@ public static DiagnosticOr From(INamedTypeSymbol type, ClassDeclar } } - if (names.Count == 0) - { - return DiagnosticOr.From(DiagnosticInfo.Create(GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface, - syntax.Identifier.GetLocation(), - type.ToDisplayString())); - } - - return DiagnosticOr.From( - new ComClassInfo( - type.ToDisplayString(), - new ContainingSyntaxContext(syntax), - new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), - new(names.ToImmutable()))); + return new ComClassInfo( + type.ToDisplayString(), + new ContainingSyntaxContext(syntax), + new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), + new(names.ToImmutable())); } public bool Equals(ComClassInfo? other) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/DiagnosticOr.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/DiagnosticOr.cs index 2fe420651f7338..0c36ae2c82ce83 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/DiagnosticOr.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/DiagnosticOr.cs @@ -2,11 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; -using System.Linq; -using Microsoft.CodeAnalysis; namespace Microsoft.Interop { @@ -88,27 +85,4 @@ public static DiagnosticOr From(T value, params DiagnosticInfo[] diagnostics) return new ValueAndDiagnostic(value, ImmutableArray.Create(diagnostics)); } } - - public static class DiagnosticOrTHelperExtensions - { - /// - /// Splits the elements of into a values provider and a diagnostics provider. - /// - public static (IncrementalValuesProvider, IncrementalValuesProvider) Split(this IncrementalValuesProvider> provider) - { - var values = provider.Where(x => x.HasValue).Select(static (x, ct) => x.Value); - var diagnostics = provider.Where(x => x.HasDiagnostic).SelectMany(static (x, ct) => x.Diagnostics); - return (values, diagnostics); - } - - /// - /// Filters the by whether or not the is a , reports the diagnostics, and returns the values. - /// - public static IncrementalValuesProvider FilterAndReportDiagnostics(this IncrementalGeneratorInitializationContext ctx, IncrementalValuesProvider> diagnosticOrValues) - { - var (values, diagnostics) = diagnosticOrValues.Split(); - ctx.RegisterDiagnostics(diagnostics); - return values; - } - } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs index 88094b0cff8518..5db454c730e068 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs @@ -1,14 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; -using System.Reflection; using System.Text; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Diagnostics; namespace Microsoft.Interop { @@ -50,22 +46,6 @@ public static IncrementalValueProvider CreateStubEnvironmentPro new StubEnvironment(data.Left, data.Right)); } - public static void RegisterDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider diagnostics) - { - context.RegisterSourceOutput(diagnostics.Where(diag => diag is not null), (context, diagnostic) => - { - context.ReportDiagnostic(diagnostic.ToDiagnostic()); - }); - } - - public static void RegisterDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider diagnostics) - { - context.RegisterSourceOutput(diagnostics.Where(diag => diag is not null), (context, diagnostic) => - { - context.ReportDiagnostic(diagnostic); - }); - } - public static void RegisterConcatenatedSyntaxOutputs(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider nodes, string fileName) where TNode : SyntaxNode { diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs index f07e6d16b6322d..c6443b5160553a 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorDiagnostics.cs @@ -7,7 +7,7 @@ using Microsoft.CodeAnalysis.Testing; using Microsoft.Interop; using Xunit; -using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; +using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace ComInterfaceGenerator.Unit.Tests { @@ -27,7 +27,6 @@ partial interface INativeAPI [GeneratedComClass] internal class {|#0:C|} : INativeAPI {} - """; await VerifyCS.VerifySourceGeneratorAsync( @@ -54,7 +53,6 @@ partial interface INativeAPI [GeneratedComClass] internal partial class {|#0:C|} : INativeAPI {} } - """; await VerifyCS.VerifySourceGeneratorAsync( @@ -72,7 +70,7 @@ protected override CompilationOptions CreateCompilationOptions() } [Fact] - public async Task UnsafeCodeNotEnabledWarns() + public async Task UnsafeCodeNotEnabledErrors() { string source = """ using System.Runtime.InteropServices; @@ -88,15 +86,52 @@ internal partial interface INativeAPI [GeneratedComClass] internal partial class {|#0:C|} : INativeAPI {} } + """; + + var test = new UnsafeBlocksNotAllowedTest(false) + { + TestCode = source, + ExpectedDiagnostics = + { + new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) + .WithLocation(0) + } + }; + + await test.RunAsync(); + } + [Fact] + public async Task UnsafeCodeNotEnabledAndNoPartialModifierProducesBothErrors() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + public partial class Test + { + [GeneratedComInterface] + internal partial interface INativeAPI + { + } + + [GeneratedComClass] + internal class {|#0:C|} : INativeAPI {} + } """; - var test = new UnsafeBlocksNotAllowedTest(false); - test.TestState.Sources.Add(source); - test.ExpectedDiagnostics.Add( - new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) - .WithLocation(0) - .WithArguments("Test.C")); + var test = new UnsafeBlocksNotAllowedTest(false) + { + TestCode = source, + ExpectedDiagnostics = + { + new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) + .WithLocation(0), + new DiagnosticResult(GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier) + .WithLocation(0) + .WithArguments("Test.C") + } + }; await test.RunAsync(); } @@ -116,7 +151,6 @@ internal interface INativeAPI [GeneratedComClass] internal partial class {|#0:C|} : INativeAPI {} } - """; await VerifyCS.VerifySourceGeneratorAsync( @@ -125,5 +159,59 @@ await VerifyCS.VerifySourceGeneratorAsync( .WithLocation(0) .WithArguments("Test.C")); } + + [Fact] + public async Task NoWarningIfErrorIsAlreadyPresent_UnsafeCodeNotEnabledError() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + public partial class Test{ + internal interface INativeAPI + { + } + + [GeneratedComClass] + internal partial class {|#0:C|} : INativeAPI {} + } + """; + + var test = new UnsafeBlocksNotAllowedTest(false) + { + TestCode = source, + ExpectedDiagnostics = + { + new DiagnosticResult(GeneratorDiagnostics.RequiresAllowUnsafeBlocks) + .WithLocation(0) + } + }; + + await test.RunAsync(); + } + + [Fact] + public async Task NoWarningIfErrorIsAlreadyPresent_NotPartialContextError() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + public partial class Test{ + internal interface INativeAPI + { + } + + [GeneratedComClass] + internal class {|#0:C|} : INativeAPI {} + } + """; + + await VerifyCS.VerifySourceGeneratorAsync( + source, + new DiagnosticResult(GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier) + .WithLocation(0) + .WithArguments("Test.C")); + } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs index c61e00b0a3d7e0..1b4dcc94807397 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs @@ -104,7 +104,7 @@ class GeneratedShapeTest : VerifyCS.Test private readonly string[] _typeNames; public GeneratedShapeTest(params string[] typeNames) - :base(referenceAncillaryInterop: false) + : base(referenceAncillaryInterop: false) { _typeNames = typeNames; } @@ -129,11 +129,9 @@ private static void VerifyShape(Compilation comp, string userDefinedClassMetadat userDefinedClass.GetAttributes(), attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, comExposedClassAttribute)); - Assert.Collection(Assert.IsAssignableFrom(iUnknownDerivedAttribute.AttributeClass).TypeArguments, - infoType => - { - Assert.True(Assert.IsAssignableFrom(infoType).IsFileLocal); - }); + Assert.NotNull(iUnknownDerivedAttribute.AttributeClass); + ITypeSymbol typeArgument = Assert.Single(iUnknownDerivedAttribute.AttributeClass.TypeArguments); + Assert.True(Assert.IsType(typeArgument, exactMatch: false).IsFileLocal); } } }