diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs new file mode 100644 index 00000000000000..b9d36ecacbf52a --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ComInterfaceGeneratorDiagnosticsAnalyzer.cs @@ -0,0 +1,314 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; + +namespace Microsoft.Interop.Analyzers +{ + [DiagnosticAnalyzer(LanguageNames.CSharp)] + public class ComInterfaceGeneratorDiagnosticsAnalyzer : DiagnosticAnalyzer + { + public override ImmutableArray SupportedDiagnostics { get; } = + ImmutableArray.Create( + // Interface-level diagnostics + GeneratorDiagnostics.RequiresAllowUnsafeBlocks, + GeneratorDiagnostics.InvalidAttributedInterfaceGenericNotSupported, + GeneratorDiagnostics.InvalidAttributedInterfaceMissingPartialModifiers, + GeneratorDiagnostics.InvalidAttributedInterfaceNotAccessible, + GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute, + GeneratorDiagnostics.InvalidStringMarshallingMismatchBetweenBaseAndDerived, + GeneratorDiagnostics.InvalidOptionsOnInterface, + GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnInterface, + GeneratorDiagnostics.InvalidExceptionToUnmanagedMarshallerType, + GeneratorDiagnostics.StringMarshallingCustomTypeNotAccessibleByGeneratedCode, + GeneratorDiagnostics.ExceptionToUnmanagedMarshallerNotAccessibleByGeneratedCode, + GeneratorDiagnostics.MultipleComInterfaceBaseTypes, + GeneratorDiagnostics.BaseInterfaceIsNotGenerated, + GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, + // Method-level diagnostics + GeneratorDiagnostics.MethodNotDeclaredInAttributedInterface, + GeneratorDiagnostics.InstancePropertyDeclaredInInterface, + GeneratorDiagnostics.InstanceEventDeclaredInInterface, + GeneratorDiagnostics.CannotAnalyzeMethodPattern, + GeneratorDiagnostics.CannotAnalyzeInterfacePattern, + // Stub-level diagnostics + GeneratorDiagnostics.ConfigurationNotSupported, + GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnMethod, + GeneratorDiagnostics.ParameterTypeNotSupported, + GeneratorDiagnostics.ReturnTypeNotSupported, + GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails, + GeneratorDiagnostics.ReturnTypeNotSupportedWithDetails, + GeneratorDiagnostics.ParameterConfigurationNotSupported, + GeneratorDiagnostics.ReturnConfigurationNotSupported, + GeneratorDiagnostics.MarshalAsParameterConfigurationNotSupported, + GeneratorDiagnostics.MarshalAsReturnConfigurationNotSupported, + GeneratorDiagnostics.ConfigurationValueNotSupported, + GeneratorDiagnostics.MarshallingAttributeConfigurationNotSupported, + GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo, + GeneratorDiagnostics.UnnecessaryReturnMarshallingInfo, + GeneratorDiagnostics.ComMethodManagedReturnWillBeOutVariable, + GeneratorDiagnostics.HResultTypeWillBeTreatedAsStruct, + GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallOutParam, + GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallReturnValue, + GeneratorDiagnostics.InvalidExceptionMarshallingConfiguration, + GeneratorDiagnostics.GeneratedComInterfaceUsageDoesNotFollowBestPractices); + + public override void Initialize(AnalysisContext context) + { + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + context.EnableConcurrentExecution(); + context.RegisterCompilationStartAction(compilationContext => + { + INamedTypeSymbol? generatedComInterfaceAttrType = compilationContext.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute); + if (generatedComInterfaceAttrType is null) + return; + + StubEnvironment env = new StubEnvironment( + compilationContext.Compilation, + compilationContext.Compilation.GetEnvironmentFlags()); + + // Cache ComInterfaceInfo per symbol for deduplication when multiple interfaces share the same base. + // This avoids recomputing the same interface info when traversing the ancestor chain of different derived interfaces. + var interfaceInfoCache = new ConcurrentDictionary>(SymbolEqualityComparer.Default); + + compilationContext.RegisterSymbolAction(symbolContext => + { + INamedTypeSymbol typeSymbol = (INamedTypeSymbol)symbolContext.Symbol; + if (typeSymbol.TypeKind != TypeKind.Interface) + return; + + // Find the [GeneratedComInterface] attribute and the syntax node of the declaring partial interface + InterfaceDeclarationSyntax? ifaceSyntax = null; + foreach (AttributeData attr in typeSymbol.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttrType)) + { + ifaceSyntax = FindInterfaceSyntaxWithAttribute(typeSymbol, generatedComInterfaceAttrType, symbolContext.CancellationToken); + break; + } + } + + if (ifaceSyntax is null) + return; + + AnalyzeInterface(symbolContext, typeSymbol, ifaceSyntax, env, generatedComInterfaceAttrType, interfaceInfoCache); + }, SymbolKind.NamedType); + }); + } + + private static void AnalyzeInterface( + SymbolAnalysisContext context, + INamedTypeSymbol typeSymbol, + InterfaceDeclarationSyntax ifaceSyntax, + StubEnvironment env, + INamedTypeSymbol generatedComInterfaceAttrType, + ConcurrentDictionary> interfaceInfoCache) + { + CancellationToken ct = context.CancellationToken; + + // Get or compute ComInterfaceInfo for this interface (cached to avoid recomputing for shared base interfaces) + DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)> ciiResult = interfaceInfoCache.GetOrAdd( + typeSymbol, _ => ComInterfaceInfo.From(typeSymbol, ifaceSyntax, env, ct)); + + // Report interface-level diagnostics + if (ciiResult.HasDiagnostic) + { + foreach (DiagnosticInfo diag in ciiResult.Diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); + } + + if (!ciiResult.HasValue) + return; + + (ComInterfaceInfo cii, INamedTypeSymbol _) = ciiResult.Value; + + // Build the context chain for this interface (ancestors first, then this interface) to detect + // BaseInterfaceIsNotGenerated. Note: vtable indices don't need to be correct here since we're + // only reporting diagnostics, not emitting code. + ImmutableArray contextChain = BuildContextChain( + typeSymbol, cii, env, generatedComInterfaceAttrType, interfaceInfoCache, ct); + + ImmutableArray> contextResults = ComInterfaceContext.GetContexts(contextChain, ct); + // BuildContextChain always appends cii as the last element, so contextResults is always non-empty. + Debug.Assert(contextResults.Length > 0); + // The last entry corresponds to this interface + DiagnosticOr thisContextResult = contextResults[contextResults.Length - 1]; + if (thisContextResult.HasDiagnostic) + { + foreach (DiagnosticInfo diag in thisContextResult.Diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); + return; + } + + // Process each method declared on this interface + foreach (DiagnosticOr<(ComMethodInfo ComMethod, IMethodSymbol Symbol)> methodResult in + ComMethodInfo.GetMethodsFromInterface((cii, typeSymbol), ct)) + { + if (methodResult.HasDiagnostic) + { + foreach (DiagnosticInfo diag in methodResult.Diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); + } + + if (!methodResult.HasValue) + continue; + + (ComMethodInfo comMethod, IMethodSymbol methodSymbol) = methodResult.Value; + + if (comMethod.Syntax is null) + continue; // externally-defined method; no stub diagnostics to report + + // Note: the vtable index passed here (0) doesn't need to be the correct vtable slot since + // we're only reporting diagnostics, not emitting code. + IncrementalMethodStubGenerationContext stubContext = ComInterfaceGenerator.CalculateStubInformation( + comMethod.Syntax, + methodSymbol, + 0, + env, + cii, + ct); + + if (stubContext is not SourceAvailableIncrementalMethodStubGenerationContext srcCtx) + continue; + + ImmutableArray managedToNativeDiags = ImmutableArray.Empty; + ImmutableArray nativeToManagedDiags = ImmutableArray.Empty; + + if (srcCtx.VtableIndexData.Direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) + { + (_, managedToNativeDiags) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(srcCtx, ComInterfaceGeneratorHelpers.GetGeneratorResolver); + } + if (srcCtx.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) + { + (_, nativeToManagedDiags) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(srcCtx, ComInterfaceGeneratorHelpers.GetGeneratorResolver); + } + + // Deduplicate diagnostics reported for both directions (matching original generator behavior) + foreach (DiagnosticInfo diag in managedToNativeDiags.Union(nativeToManagedDiags)) + context.ReportDiagnostic(diag.ToDiagnostic()); + } + } + + /// + /// Builds the ancestor chain for context creation (root-to-parent order, then the current interface last). + /// Only successfully-computed ancestors are included; if an ancestor fails, the chain stops there and + /// will emit + /// for the next derived interface. + /// + private static ImmutableArray BuildContextChain( + INamedTypeSymbol typeSymbol, + ComInterfaceInfo cii, + StubEnvironment env, + INamedTypeSymbol generatedComInterfaceAttrType, + ConcurrentDictionary> interfaceInfoCache, + CancellationToken ct) + { + // For external base interfaces, CreateInterfaceInfoForBaseInterfacesInOtherCompilations already + // provides the full ancestor chain ordered from root to immediate parent. + ImmutableArray<(ComInterfaceInfo, INamedTypeSymbol)> externalBases = + ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(typeSymbol); + if (!externalBases.IsEmpty) + { + return [.. externalBases.Select(static e => e.Item1), cii]; + } + + // Traverse same-compilation base interfaces, inserting at the front to get root-first order. + var ancestorChain = new List(); + INamedTypeSymbol current = typeSymbol; + + while (true) + { + INamedTypeSymbol? baseSymbol = FindBaseComInterfaceSymbol(current, generatedComInterfaceAttrType); + if (baseSymbol is null) + break; + + if (!SymbolEqualityComparer.Default.Equals(baseSymbol.ContainingAssembly, typeSymbol.ContainingAssembly)) + { + // Switch to external base handling + ImmutableArray<(ComInterfaceInfo, INamedTypeSymbol)> externalInfos = + ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(current); + ancestorChain.InsertRange(0, externalInfos.Select(static e => e.Item1)); + break; + } + + // Get or compute the base's ComInterfaceInfo (using the cache for deduplication) + DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)> baseResult = interfaceInfoCache.GetOrAdd( + baseSymbol, + sym => + { + InterfaceDeclarationSyntax? baseSyntax = FindInterfaceSyntaxWithAttribute(sym, generatedComInterfaceAttrType, ct); + if (baseSyntax is null) + return DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)>.From( + DiagnosticInfo.Create(GeneratorDiagnostics.CannotAnalyzeInterfacePattern, sym.Locations.FirstOrDefault() ?? Location.None, sym.Name)); + return ComInterfaceInfo.From(sym, baseSyntax, env, ct); + }); + + if (!baseResult.HasValue) + break; // Base failed — GetContexts will report BaseInterfaceIsNotGenerated for this interface + + ancestorChain.Insert(0, baseResult.Value.Item1); + current = baseSymbol; + } + + ancestorChain.Add(cii); + return ancestorChain.ToImmutableArray(); + } + + /// + /// Finds the first direct base interface of that has the . + /// + private static INamedTypeSymbol? FindBaseComInterfaceSymbol(INamedTypeSymbol typeSymbol, INamedTypeSymbol generatedComInterfaceAttrType) + { + foreach (INamedTypeSymbol iface in typeSymbol.Interfaces) + { + foreach (AttributeData attr in iface.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttrType)) + return iface; + } + } + return null; + } + + /// + /// Finds the for that carries the . + /// For partial types, this is the specific partial declaration that has the attribute. + /// + private static InterfaceDeclarationSyntax? FindInterfaceSyntaxWithAttribute( + INamedTypeSymbol symbol, + INamedTypeSymbol generatedComInterfaceAttrType, + CancellationToken ct) + { + foreach (AttributeData attr in symbol.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttrType)) + { + SyntaxReference? attrSyntaxRef = attr.ApplicationSyntaxReference; + if (attrSyntaxRef is not null) + { + SyntaxNode attrSyntax = attrSyntaxRef.GetSyntax(ct); + // Attribute syntax structure: AttributeSyntax -> AttributeListSyntax -> InterfaceDeclarationSyntax + if (attrSyntax.Parent?.Parent is InterfaceDeclarationSyntax ifaceSyntax) + return ifaceSyntax; + } + foreach (SyntaxReference syntaxRef in symbol.DeclaringSyntaxReferences) + { + if (syntaxRef.GetSyntax(ct) is InterfaceDeclarationSyntax ifaceSyntax) + return ifaceSyntax; + } + break; + } + } + return null; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/VtableIndexStubDiagnosticsAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/VtableIndexStubDiagnosticsAnalyzer.cs new file mode 100644 index 00000000000000..01af871b46328e --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/VtableIndexStubDiagnosticsAnalyzer.cs @@ -0,0 +1,143 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; + +namespace Microsoft.Interop.Analyzers +{ + [DiagnosticAnalyzer(LanguageNames.CSharp)] + public class VtableIndexStubDiagnosticsAnalyzer : DiagnosticAnalyzer + { + public override ImmutableArray SupportedDiagnostics { get; } = + ImmutableArray.Create( + GeneratorDiagnostics.InvalidAttributedMethodSignature, + GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, + GeneratorDiagnostics.ReturnConfigurationNotSupported, + GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingUnmanagedObjectUnwrapperAttribute, + GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnMethod, + GeneratorDiagnostics.InvalidExceptionMarshallingConfiguration, + GeneratorDiagnostics.ConfigurationNotSupported, + GeneratorDiagnostics.ParameterTypeNotSupported, + GeneratorDiagnostics.ReturnTypeNotSupported, + GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails, + GeneratorDiagnostics.ReturnTypeNotSupportedWithDetails, + GeneratorDiagnostics.ParameterConfigurationNotSupported, + GeneratorDiagnostics.MarshalAsParameterConfigurationNotSupported, + GeneratorDiagnostics.MarshalAsReturnConfigurationNotSupported, + GeneratorDiagnostics.ConfigurationValueNotSupported, + GeneratorDiagnostics.MarshallingAttributeConfigurationNotSupported, + GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo, + GeneratorDiagnostics.UnnecessaryReturnMarshallingInfo, + GeneratorDiagnostics.GeneratedComInterfaceUsageDoesNotFollowBestPractices); + + public override void Initialize(AnalysisContext context) + { + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + context.EnableConcurrentExecution(); + context.RegisterCompilationStartAction(compilationContext => + { + INamedTypeSymbol? virtualMethodIndexAttrType = compilationContext.Compilation.GetBestTypeByMetadataName(TypeNames.VirtualMethodIndexAttribute); + if (virtualMethodIndexAttrType is null) + return; + + StubEnvironment env = new StubEnvironment( + compilationContext.Compilation, + compilationContext.Compilation.GetEnvironmentFlags()); + + compilationContext.RegisterSymbolAction(symbolContext => + { + IMethodSymbol method = (IMethodSymbol)symbolContext.Symbol; + AttributeData? virtualMethodIndexAttr = null; + foreach (AttributeData attr in method.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, virtualMethodIndexAttrType)) + { + virtualMethodIndexAttr = attr; + break; + } + } + + if (virtualMethodIndexAttr is null) + return; + + foreach (SyntaxReference syntaxRef in method.DeclaringSyntaxReferences) + { + if (syntaxRef.GetSyntax(symbolContext.CancellationToken) is MethodDeclarationSyntax methodSyntax) + { + AnalyzeMethod(symbolContext, methodSyntax, method, env); + break; + } + } + }, SymbolKind.Method); + }); + } + + private static void AnalyzeMethod(SymbolAnalysisContext context, MethodDeclarationSyntax methodSyntax, IMethodSymbol method, StubEnvironment env) + { + DiagnosticInfo? invalidMethodDiagnostic = GetDiagnosticIfInvalidMethodForGeneration(methodSyntax, method); + if (invalidMethodDiagnostic is not null) + { + context.ReportDiagnostic(invalidMethodDiagnostic.ToDiagnostic()); + return; + } + + SourceAvailableIncrementalMethodStubGenerationContext stubContext = VtableIndexStubGenerator.CalculateStubInformation(methodSyntax, method, env, context.CancellationToken); + + if (stubContext.VtableIndexData.Direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) + { + var (_, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(stubContext, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); + foreach (DiagnosticInfo diag in diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); + } + + if (stubContext.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) + { + var (_, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(stubContext, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); + foreach (DiagnosticInfo diag in diagnostics) + context.ReportDiagnostic(diag.ToDiagnostic()); + } + } + + internal static DiagnosticInfo? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method) + { + // Verify the method has no generic types or defined implementation + // and is not marked static or sealed + if (methodSyntax.TypeParameterList is not null + || methodSyntax.Body is not null + || methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword) + || methodSyntax.Modifiers.Any(SyntaxKind.SealedKeyword)) + { + return DiagnosticInfo.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, methodSyntax.Identifier.GetLocation(), method.Name); + } + + // Verify that the types the method is declared in are marked partial. + for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) + { + if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) + { + return DiagnosticInfo.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, methodSyntax.Identifier.GetLocation(), method.Name, typeDecl.Identifier); + } + } + + // Verify the method does not have a ref return + if (method.ReturnsByRef || method.ReturnsByRefReadonly) + { + return DiagnosticInfo.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, methodSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString()); + } + + // Verify there is an [UnmanagedObjectUnwrapperAttribute] + if (!method.ContainingType.GetAttributes().Any(att => att.AttributeClass.IsOfType(TypeNames.UnmanagedObjectUnwrapperAttribute))) + { + return DiagnosticInfo.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingUnmanagedObjectUnwrapperAttribute, methodSyntax.Identifier.GetLocation(), method.Name); + } + + return null; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index e10a2ddd2f67dc..4d71dc9d788a99 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -56,7 +56,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { return ( - Diagnostics: ImmutableArray.Empty.ToSequenceEqual(), InterfaceContexts: ImmutableArray.Empty.ToSequenceEqual(), MethodContexts: ImmutableArray.Empty.ToSequenceEqual() ); @@ -64,15 +63,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) StubEnvironment stubEnvironment = input.Right; List<(ComInterfaceInfo, INamedTypeSymbol)> interfaceInfos = new(); HashSet<(ComInterfaceInfo, INamedTypeSymbol)> externalIfaces = new(ComInterfaceInfo.EqualityComparerForExternalIfaces.Instance); - List diags = new(); foreach (var (syntax, symbol) in input.Left) { var cii = ComInterfaceInfo.From(symbol, syntax, stubEnvironment, CancellationToken.None); - if (cii.HasDiagnostic) - { - foreach (var diag in cii.Diagnostics) - diags.Add(diag); - } if (cii.HasValue) interfaceInfos.Add(cii.Value); var externalBase = ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(symbol); @@ -98,13 +91,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var inner = new List(); foreach (var m in cmi) { - if (m.HasDiagnostic) - { - foreach (var diag in m.Diagnostics) - { - diags.Add(diag); - } - } if (m.HasValue) { inner.Add(m.Value.ComMethod); @@ -118,14 +104,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) for (int i = 0; i < interfaceInfos.Count; i++) { var cic = comInterfaceContexts[i]; - var cii = interfaceInfos[i]; - if (cic.HasDiagnostic) - { - foreach (var diag in cic.Diagnostics) - { - diags.Add(diag); - } - } if (cic.HasValue) { ifaceCtxs.Add((cic.Value, methods[i].ToSequenceEqualImmutableArray())); @@ -151,14 +129,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return ( - Diagnostics: diags.ToSequenceEqualImmutableArray(), InterfaceContexts: ifaceCtxs.Select(x => x.Item1).Where(x => !x.IsExternallyDefined).ToSequenceEqualImmutableArray(), MethodContexts: methodContexts.ToSequenceEqualImmutableArray() ); }); - context.RegisterDiagnostics(attributedInterfaces.SelectMany(static (data, ct) => data.Diagnostics)); - // Create list of methods (inherited and declared) and their owning interface var interfaceContextsToGenerate = attributedInterfaces.SelectMany(static (a, ct) => a.InterfaceContexts); var comMethodContexts = attributedInterfaces.Select(static (a, ct) => a.MethodContexts); @@ -185,11 +160,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) GenerateIUnknownDerivedAttributeApplication(x.Interface.Info, ct).NormalizeWhitespace() ])); - // Report diagnostics for managed-to-unmanaged and unmanaged-to-managed stubs, deduplicating diagnostics that are reported for both. - context.RegisterDiagnostics( - interfaceAndMethodsContexts - .SelectMany(static (data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics).Union(data.DeclaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics)))); - var filesToGenerate = syntaxes .Select(static (methodSyntaxes, ct) => { @@ -443,7 +413,7 @@ private static IncrementalMethodStubGenerationContext CalculateSharedStubInforma ComInterfaceDispatchMarshallingInfo.Instance); } - private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct) + internal static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct) { ISignatureDiagnosticLocations locations = syntax is null ? NoneSignatureDiagnosticLocations.Instance diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs index e46ab33137dab4..4db8815cb94a8e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs @@ -45,20 +45,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Where( static modelData => modelData is not null); - var methodsWithDiagnostics = attributedMethods.Select(static (data, ct) => - { - Diagnostic? diagnostic = GetDiagnosticIfInvalidMethodForGeneration(data.Syntax, data.Symbol); - return new { data.Syntax, data.Symbol, Diagnostic = diagnostic }; - }); - - // Split the methods we want to generate and the ones we don't into two separate groups. - var methodsToGenerate = methodsWithDiagnostics.Where(static data => data.Diagnostic is null); - var invalidMethodDiagnostics = methodsWithDiagnostics.Where(static data => data.Diagnostic is not null); - - context.RegisterSourceOutput(invalidMethodDiagnostics, static (context, invalidMethod) => - { - context.ReportDiagnostic(invalidMethod.Diagnostic); - }); + // Filter out methods that are invalid for generation (diagnostics for invalid methods are reported by the analyzer). + var methodsToGenerate = attributedMethods.Where( + static data => data is not null && VtableIndexStubDiagnosticsAnalyzer.GetDiagnosticIfInvalidMethodForGeneration(data.Syntax, data.Symbol) is null); // Calculate all of information to generate both managed-to-unmanaged and unmanaged-to-managed stubs // for each method. @@ -75,35 +64,31 @@ public void Initialize(IncrementalGeneratorInitializationContext context) ) .WithTrackingName(StepNames.CalculateStubInformation); - // Generate the code for the managed-to-unmangaed stubs and the diagnostics from code-generation. - IncrementalValuesProvider<(MemberDeclarationSyntax, ImmutableArray)> generateManagedToNativeStub = generateStubInformation + // Generate the code for the managed-to-unmanaged stubs. + IncrementalValuesProvider generateManagedToNativeStub = generateStubInformation .Where(data => data.VtableIndexData.Direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) .Select( static (data, ct) => GenerateManagedToNativeStub(data) ) - .WithComparer(Comparers.GeneratedSyntax) + .WithComparer(SyntaxEquivalentComparer.Instance) .WithTrackingName(StepNames.GenerateManagedToNativeStub); - context.RegisterDiagnostics(generateManagedToNativeStub.SelectMany((stubInfo, ct) => stubInfo.Item2)); - - context.RegisterConcatenatedSyntaxOutputs(generateManagedToNativeStub.Select((data, ct) => data.Item1), "ManagedToNativeStubs.g.cs"); + context.RegisterConcatenatedSyntaxOutputs(generateManagedToNativeStub, "ManagedToNativeStubs.g.cs"); // Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation. IncrementalValuesProvider nativeToManagedStubContexts = generateStubInformation .Where(data => data.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional); - // Generate the code for the unmanaged-to-managed stubs and the diagnostics from code-generation. - IncrementalValuesProvider<(MemberDeclarationSyntax, ImmutableArray)> generateNativeToManagedStub = nativeToManagedStubContexts + // Generate the code for the unmanaged-to-managed stubs. + IncrementalValuesProvider generateNativeToManagedStub = nativeToManagedStubContexts .Select( static (data, ct) => GenerateNativeToManagedStub(data) ) - .WithComparer(Comparers.GeneratedSyntax) + .WithComparer(SyntaxEquivalentComparer.Instance) .WithTrackingName(StepNames.GenerateNativeToManagedStub); - context.RegisterDiagnostics(generateNativeToManagedStub.SelectMany((stubInfo, ct) => stubInfo.Item2)); - - context.RegisterConcatenatedSyntaxOutputs(generateNativeToManagedStub.Select((data, ct) => data.Item1), "NativeToManagedStubs.g.cs"); + context.RegisterConcatenatedSyntaxOutputs(generateNativeToManagedStub, "NativeToManagedStubs.g.cs"); // Generate the native interface metadata for each interface that contains a method with the [VirtualMethodIndex] attribute. IncrementalValuesProvider generateNativeInterface = generateStubInformation @@ -195,7 +180,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }; } - private static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct) + internal static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct) { ct.ThrowIfCancellationRequested(); INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute); @@ -362,66 +347,26 @@ private static MarshallingInfo CreateExceptionMarshallingInfo(AttributeData virt return NoMarshallingInfo.Instance; } - private static (MemberDeclarationSyntax, ImmutableArray) GenerateManagedToNativeStub( + private static MemberDeclarationSyntax GenerateManagedToNativeStub( SourceAvailableIncrementalMethodStubGenerationContext methodStub) { - var (stub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); + var (stub, _) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); - return ( - methodStub.ContainingSyntaxContext.AddContainingSyntax( + return methodStub.ContainingSyntaxContext.AddContainingSyntax( NativeTypeContainingSyntax) .WrapMemberInContainingSyntaxWithUnsafeModifier( - stub), - methodStub.Diagnostics.Array.AddRange(diagnostics)); + stub); } - private static (MemberDeclarationSyntax, ImmutableArray) GenerateNativeToManagedStub( + private static MemberDeclarationSyntax GenerateNativeToManagedStub( SourceAvailableIncrementalMethodStubGenerationContext methodStub) { - var (stub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); + var (stub, _) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver); - return ( - methodStub.ContainingSyntaxContext.AddContainingSyntax( + return methodStub.ContainingSyntaxContext.AddContainingSyntax( NativeTypeContainingSyntax) .WrapMemberInContainingSyntaxWithUnsafeModifier( - stub), - methodStub.Diagnostics.Array.AddRange(diagnostics)); - } - - private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method) - { - // Verify the method has no generic types or defined implementation - // and is not marked static or sealed - if (methodSyntax.TypeParameterList is not null - || methodSyntax.Body is not null - || methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword) - || methodSyntax.Modifiers.Any(SyntaxKind.SealedKeyword)) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, methodSyntax.Identifier.GetLocation(), method.Name); - } - - // Verify that the types the method is declared in are marked partial. - for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) - { - if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, methodSyntax.Identifier.GetLocation(), method.Name, typeDecl.Identifier); - } - } - - // Verify the method does not have a ref return - if (method.ReturnsByRef || method.ReturnsByRefReadonly) - { - return Diagnostic.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, methodSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString()); - } - - // Verify there is an [UnmanagedObjectUnwrapperAttribute] - if (!method.ContainingType.GetAttributes().Any(att => att.AttributeClass.IsOfType(TypeNames.UnmanagedObjectUnwrapperAttribute))) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingUnmanagedObjectUnwrapperAttribute, methodSyntax.Identifier.GetLocation(), method.Name); - } - - return null; + stub); } private static MemberDeclarationSyntax GenerateNativeInterfaceMetadata(ContainingSyntaxContext context) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddMarshalAsToElementTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddMarshalAsToElementTests.cs index d7858a7fc75feb..21c84aef5698ed 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddMarshalAsToElementTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddMarshalAsToElementTests.cs @@ -10,7 +10,7 @@ using Xunit; using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpCodeFixVerifier< - Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer, + Microsoft.Interop.Analyzers.ComInterfaceGeneratorDiagnosticsAnalyzer, Microsoft.Interop.Analyzers.AddMarshalAsToElementFixer>; namespace ComInterfaceGenerator.Unit.Tests diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs index 3bd2bae74d92d0..880c6a5f9bc7a0 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs @@ -10,7 +10,7 @@ using Xunit; using static Microsoft.Interop.UnitTests.TestUtils; using StringMarshalling = System.Runtime.InteropServices.StringMarshalling; -using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; +using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace ComInterfaceGenerator.Unit.Tests { diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs index a908d720673dfc..19aefa55d689d2 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs @@ -15,7 +15,7 @@ using Xunit; using static Microsoft.Interop.UnitTests.TestUtils; using StringMarshalling = System.Runtime.InteropServices.StringMarshalling; -using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; +using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace ComInterfaceGenerator.Unit.Tests { diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs index 2e7bb5815a27fd..e9729c57a075ae 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs @@ -15,7 +15,7 @@ using Microsoft.Interop; using Xunit; -using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; +using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace ComInterfaceGenerator.Unit.Tests {