Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;

namespace Microsoft.Interop.Analyzers;

[DiagnosticAnalyzer(LanguageNames.CSharp)]
public sealed class ComClassGeneratorDiagnosticsAnalyzer : DiagnosticAnalyzer
{
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } =
ImmutableArray.Create(
GeneratorDiagnostics.RequiresAllowUnsafeBlocks,
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface);

public override void Initialize(AnalysisContext context)
{
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
context.EnableConcurrentExecution();

context.RegisterCompilationStartAction(static context =>
{
bool unsafeCodeIsEnabled = context.Compilation.Options is CSharpCompilationOptions { AllowUnsafe: true };
INamedTypeSymbol? generatedComClassAttributeType = context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComClassAttribute);

// We use this type only to report warning diagnostic. We also don't report a warning if there is at least one error.
// Given that with unsafe code disabled we will get an error on each declaration, we can skip
// unnecessary work of getting this symbol here
INamedTypeSymbol? generatedComInterfaceAttributeType = unsafeCodeIsEnabled
? context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute)
: null;

context.RegisterSymbolAction(context => AnalyzeNamedType(context, unsafeCodeIsEnabled, generatedComClassAttributeType, generatedComInterfaceAttributeType), SymbolKind.NamedType);
});
}

private static void AnalyzeNamedType(SymbolAnalysisContext context, bool unsafeCodeIsEnabled, INamedTypeSymbol? generatedComClassAttributeType, INamedTypeSymbol? generatedComInterfaceAttributeType)
{
if (context.Symbol is not INamedTypeSymbol { TypeKind: TypeKind.Class } classToAnalyze)
{
return;
}

if (!classToAnalyze.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, generatedComClassAttributeType)))
{
return;
}

foreach (Diagnostic diagnostic in GetDiagnosticsForAnnotatedClass(classToAnalyze, unsafeCodeIsEnabled, generatedComInterfaceAttributeType))
{
context.ReportDiagnostic(diagnostic);
}
}

public static IEnumerable<Diagnostic> GetDiagnosticsForAnnotatedClass(INamedTypeSymbol annotatedClass, bool unsafeCodeIsEnabled, INamedTypeSymbol? generatedComInterfaceAttributeType)
{
Location location = annotatedClass.Locations.First();
bool hasErrors = false;

if (!unsafeCodeIsEnabled)
{
yield return Diagnostic.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, location);
hasErrors = true;
}

var declarationNode = (TypeDeclarationSyntax)location.SourceTree.GetRoot().FindNode(location.SourceSpan);

if (!declarationNode.IsInPartialContext(out _))
{
yield return Diagnostic.Create(
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
location,
annotatedClass);
hasErrors = true;
}

if (hasErrors)
{
// If we already reported at least one error avoid stacking a warning on top of it
yield break;
}

foreach (INamedTypeSymbol iface in annotatedClass.AllInterfaces)
{
if (iface.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, generatedComInterfaceAttributeType)) is { } generatedComInterfaceAttribute &&
GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute).Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper))
{
yield break;
}
}

// Class doesn't implement any generated COM interface. Report a warning about that
yield return Diagnostic.Create(
GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface,
location,
annotatedClass);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
using Microsoft.Interop.Analyzers;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Microsoft.Interop.SyntaxFactoryExtensions;

Expand All @@ -18,24 +20,28 @@ public class ComClassGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var unsafeCodeIsEnabled = context.CompilationProvider.Select((comp, ct) => comp.Options is CSharpCompilationOptions { AllowUnsafe: true }); // Unsafe code enabled
// Get all types with the [GeneratedComClassAttribute] attribute.
var attributedClassesOrDiagnostics = context.SyntaxProvider
var attributedClasses = context.SyntaxProvider
.ForAttributeWithMetadataName(
TypeNames.GeneratedComClassAttribute,
static (node, ct) => node is ClassDeclarationSyntax,
static (context, ct) => context)
.Combine(unsafeCodeIsEnabled)
.Select(static (data, ct) =>
static (context, _) =>
{
var context = data.Left;
var unsafeCodeIsEnabled = data.Right;
var type = (INamedTypeSymbol)context.TargetSymbol;
var syntax = (ClassDeclarationSyntax)context.TargetNode;
return ComClassInfo.From(type, syntax, unsafeCodeIsEnabled);
});
var compilation = context.SemanticModel.Compilation;
var unsafeCodeIsEnabled = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true };
INamedTypeSymbol? generatedComInterfaceAttributeType = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute);

var attributedClasses = context.FilterAndReportDiagnostics(attributedClassesOrDiagnostics);
// Currently all reported diagnostics are fatal to the generator
if (ComClassGeneratorDiagnosticsAnalyzer.GetDiagnosticsForAnnotatedClass(type, unsafeCodeIsEnabled, generatedComInterfaceAttributeType).Any())
{
return null;
}

return ComClassInfo.From(type, syntax, generatedComInterfaceAttributeType);
})
.Where(static info => info is not null);

var classInfoType = attributedClasses
.Select(static (info, ct) => new ItemAndSyntaxes<ComClassInfo>(info,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,12 @@ private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxC
ImplementedInterfacesNames = implementedInterfacesNames;
}

public static DiagnosticOr<ComClassInfo> From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, bool unsafeCodeIsEnabled)
public static ComClassInfo From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, INamedTypeSymbol? generatedComInterfaceAttributeType)
{
if (!unsafeCodeIsEnabled)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, syntax.Identifier.GetLocation()));
}

if (!syntax.IsInPartialContext(out _))
{
return DiagnosticOr<ComClassInfo>.From(
DiagnosticInfo.Create(
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}

ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
foreach (INamedTypeSymbol iface in type.AllInterfaces)
{
AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute);
AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttributeType));
if (generatedComInterfaceAttribute is not null)
{
var attributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute);
Expand All @@ -53,19 +39,11 @@ public static DiagnosticOr<ComClassInfo> From(INamedTypeSymbol type, ClassDeclar
}
}

if (names.Count == 0)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}

return DiagnosticOr<ComClassInfo>.From(
new ComClassInfo(
type.ToDisplayString(),
new ContainingSyntaxContext(syntax),
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
new(names.ToImmutable())));
return new ComClassInfo(
type.ToDisplayString(),
new ContainingSyntaxContext(syntax),
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
new(names.ToImmutable()));
}

public bool Equals(ComClassInfo? other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace Microsoft.Interop
{
Expand Down Expand Up @@ -88,27 +85,4 @@ public static DiagnosticOr<T> From(T value, params DiagnosticInfo[] diagnostics)
return new ValueAndDiagnostic(value, ImmutableArray.Create(diagnostics));
}
}

public static class DiagnosticOrTHelperExtensions
{
/// <summary>
/// Splits the elements of <paramref name="provider"/> into a values provider and a diagnostics provider.
/// </summary>
public static (IncrementalValuesProvider<T>, IncrementalValuesProvider<DiagnosticInfo>) Split<T>(this IncrementalValuesProvider<DiagnosticOr<T>> provider)
{
var values = provider.Where(x => x.HasValue).Select(static (x, ct) => x.Value);
var diagnostics = provider.Where(x => x.HasDiagnostic).SelectMany(static (x, ct) => x.Diagnostics);
return (values, diagnostics);
}

/// <summary>
/// Filters the <see cref="IncrementalValuesProvider{TValue}"/> by whether or not the is a <see cref="Diagnostic"/>, reports the diagnostics, and returns the values.
/// </summary>
public static IncrementalValuesProvider<T> FilterAndReportDiagnostics<T>(this IncrementalGeneratorInitializationContext ctx, IncrementalValuesProvider<DiagnosticOr<T>> diagnosticOrValues)
{
var (values, diagnostics) = diagnosticOrValues.Split();
ctx.RegisterDiagnostics(diagnostics);
return values;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Reflection;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;

namespace Microsoft.Interop
{
Expand Down Expand Up @@ -50,22 +46,6 @@ public static IncrementalValueProvider<StubEnvironment> CreateStubEnvironmentPro
new StubEnvironment(data.Left, data.Right));
}

public static void RegisterDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<DiagnosticInfo> diagnostics)
{
context.RegisterSourceOutput(diagnostics.Where(diag => diag is not null), (context, diagnostic) =>
{
context.ReportDiagnostic(diagnostic.ToDiagnostic());
});
}

public static void RegisterDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<Diagnostic> diagnostics)
{
context.RegisterSourceOutput(diagnostics.Where(diag => diag is not null), (context, diagnostic) =>
{
context.ReportDiagnostic(diagnostic);
});
}

public static void RegisterConcatenatedSyntaxOutputs<TNode>(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<TNode> nodes, string fileName)
where TNode : SyntaxNode
{
Expand Down
Loading
Loading