diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/AddGeneratedComClassFixer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/AddGeneratedComClassFixer.cs index f61e6c61a1279b..e882713eef278e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/AddGeneratedComClassFixer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/AddGeneratedComClassFixer.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Immutable; using System.Composition; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis; @@ -21,8 +22,12 @@ public class AddGeneratedComClassFixer : ConvertToSourceGeneratedInteropFixer protected override string BaseEquivalenceKey => nameof(AddGeneratedComClassFixer); - private static Task AddGeneratedComClassAsync(DocumentEditor editor, SyntaxNode node) + private static async Task AddGeneratedComClassAsync(SolutionEditor solutionEditor, DocumentId documentId, SyntaxNode node, CancellationToken ct) { + var editor = await solutionEditor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false); + + var declaringType = editor.SemanticModel.GetDeclaredSymbol(node, ct) as INamedTypeSymbol; + editor.ReplaceNode(node, (node, gen) => { var attribute = gen.Attribute(gen.TypeExpression(editor.SemanticModel.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComClassAttribute)).WithAdditionalAnnotations(Simplifier.AddImportsAnnotation)); @@ -37,12 +42,36 @@ private static Task AddGeneratedComClassAsync(DocumentEditor editor, SyntaxNode MakeNodeParentsPartial(editor, node); - return Task.CompletedTask; + if (declaringType is not null) + { + var comVisibleAttributeType = editor.SemanticModel.Compilation.GetBestTypeByMetadataName(TypeNames.System_Runtime_InteropServices_ComVisibleAttribute); + if (comVisibleAttributeType is not null) + { + var comVisibleAttributes = declaringType.GetAttributes().Where(attr => + SymbolEqualityComparer.Default.Equals(attr.AttributeClass, comVisibleAttributeType) + && attr.ConstructorArguments.Length == 1 + && attr.ConstructorArguments[0].Value is true).ToArray(); + + foreach (var comVisibleAttr in comVisibleAttributes) + { + if (comVisibleAttr.ApplicationSyntaxReference is { } syntaxRef) + { + var comVisibleAttrSyntax = await syntaxRef.GetSyntaxAsync(ct).ConfigureAwait(false); + var attrDocumentId = solutionEditor.OriginalSolution.GetDocumentId(syntaxRef.SyntaxTree); + if (attrDocumentId is not null) + { + var attrEditor = await solutionEditor.GetDocumentEditorAsync(attrDocumentId, ct).ConfigureAwait(false); + attrEditor.RemoveNode(comVisibleAttrSyntax); + } + } + } + } + } } - protected override Func CreateFixForSelectedOptions(SyntaxNode node, ImmutableDictionary selectedOptions) + protected override Func CreateFixForSelectedOptions(SyntaxNode node, ImmutableDictionary selectedOptions) { - return (editor, _) => AddGeneratedComClassAsync(editor, node); + return (solutionEditor, documentId, ct) => AddGeneratedComClassAsync(solutionEditor, documentId, node, ct); } protected override string GetDiagnosticTitle(ImmutableDictionary selectedOptions) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ConvertComImportToGeneratedComInterfaceFixer.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ConvertComImportToGeneratedComInterfaceFixer.cs index f0a06e7fad561e..60ae005cd24028 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ConvertComImportToGeneratedComInterfaceFixer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Analyzers/ConvertComImportToGeneratedComInterfaceFixer.cs @@ -32,12 +32,16 @@ protected override string GetDiagnosticTitle(ImmutableDictionary : SR.ConvertToGeneratedComInterfaceTitle; } - protected override Func CreateFixForSelectedOptions(SyntaxNode node, ImmutableDictionary selectedOptions) + protected override Func CreateFixForSelectedOptions(SyntaxNode node, ImmutableDictionary selectedOptions) { bool mayRequireAdditionalWork = selectedOptions.TryGetValue(Option.MayRequireAdditionalWork, out Option mayRequireAdditionalWorkOption) && mayRequireAdditionalWorkOption is Option.Bool(true); bool addStringMarshalling = selectedOptions.TryGetValue(AddStringMarshallingOption, out Option addStringMarshallingOption) && addStringMarshallingOption is Option.Bool(true); - return (editor, ct) => ConvertComImportToGeneratedComInterfaceAsync(editor, node, mayRequireAdditionalWork, addStringMarshalling, ct); + return async (solutionEditor, documentId, ct) => + { + var editor = await solutionEditor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false); + await ConvertComImportToGeneratedComInterfaceAsync(editor, node, mayRequireAdditionalWork, addStringMarshalling, ct).ConfigureAwait(false); + }; } protected override ImmutableDictionary ParseOptionsFromDiagnostic(Diagnostic diagnostic) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/ConvertToSourceGeneratedInteropFixer.cs b/src/libraries/System.Runtime.InteropServices/gen/Common/ConvertToSourceGeneratedInteropFixer.cs index dc2cef27f3b926..731eacf957f706 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/ConvertToSourceGeneratedInteropFixer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/ConvertToSourceGeneratedInteropFixer.cs @@ -30,7 +30,7 @@ protected virtual IEnumerable CreateAllFixes yield return new ConvertToSourceGeneratedInteropFix(CreateFixForSelectedOptions(node, options), options); } - protected abstract Func CreateFixForSelectedOptions(SyntaxNode node, ImmutableDictionary selectedOptions); + protected abstract Func CreateFixForSelectedOptions(SyntaxNode node, ImmutableDictionary selectedOptions); protected abstract string GetDiagnosticTitle(ImmutableDictionary selectedOptions); @@ -103,11 +103,10 @@ private ImmutableDictionary GetOptionsForIndividualFix(Immutable return CombineOptions(fixAllOptions, ParseOptionsFromDiagnostic(diagnostic)); } - private static async Task ApplyActionAndEnableUnsafe(Solution solution, DocumentId documentId, Func documentBasedFix, CancellationToken ct) + private static async Task ApplyActionAndEnableUnsafe(Solution solution, DocumentId documentId, Func solutionBasedFix, CancellationToken ct) { var editor = new SolutionEditor(solution); - var docEditor = await editor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false); - await documentBasedFix(docEditor, ct).ConfigureAwait(false); + await solutionBasedFix(editor, documentId, ct).ConfigureAwait(false); var docProjectId = documentId.ProjectId; var updatedSolution = editor.GetChangedSolution(); @@ -148,11 +147,9 @@ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context) GetDiagnosticTitle(fix.SelectedOptions), async ct => { - DocumentEditor editor = await DocumentEditor.CreateAsync(doc, ct).ConfigureAwait(false); - - await fix.ApplyFix(editor, ct).ConfigureAwait(false); - - return editor.GetChangedDocument(); + var solutionEditor = new SolutionEditor(doc.Project.Solution); + await fix.ApplyFix(solutionEditor, doc.Id, ct).ConfigureAwait(false); + return solutionEditor.GetChangedSolution(); }, Option.CreateEquivalenceKeyFromOptions(BaseEquivalenceKey, fix.SelectedOptions)), diagnostic); @@ -161,7 +158,7 @@ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context) } } - protected record struct ConvertToSourceGeneratedInteropFix(Func ApplyFix, ImmutableDictionary SelectedOptions); + protected record struct ConvertToSourceGeneratedInteropFix(Func ApplyFix, ImmutableDictionary SelectedOptions); private sealed class CustomFixAllProvider : FixAllProvider { @@ -196,14 +193,13 @@ private sealed class CustomFixAllProvider : FixAllProvider continue; } DocumentId documentId = solutionEditor.OriginalSolution.GetDocumentId(diagnostic.Location.SourceTree)!; - DocumentEditor editor = await solutionEditor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false); SyntaxNode root = await diagnostic.Location.SourceTree.GetRootAsync(ct).ConfigureAwait(false); SyntaxNode node = root.FindNode(diagnostic.Location.SourceSpan); - var documentBasedFix = codeFixProvider.CreateFixForSelectedOptions(node, codeFixProvider.GetOptionsForIndividualFix(options, diagnostic)); + var solutionBasedFix = codeFixProvider.CreateFixForSelectedOptions(node, codeFixProvider.GetOptionsForIndividualFix(options, diagnostic)); - await documentBasedFix(editor, ct).ConfigureAwait(false); + await solutionBasedFix(solutionEditor, documentId, ct).ConfigureAwait(false); // Record this project as a project we need to allow unsafe blocks in. projectsToAddUnsafe.Add(solutionEditor.OriginalSolution.GetDocument(documentId).Project); diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportFixer.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportFixer.cs index ef8c15d621c148..5d2e1e2d5f5589 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportFixer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/ConvertToLibraryImportFixer.cs @@ -83,13 +83,16 @@ protected override IEnumerable CreateAllFixe var selectedOptions = options.Remove(CharSetOption); yield return new ConvertToSourceGeneratedInteropFix( - (editor, ct) => - ConvertToLibraryImport( + async (solutionEditor, documentId, ct) => + { + var editor = await solutionEditor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false); + await ConvertToLibraryImport( editor, node, warnForAdditionalWork, null, - ct), + ct).ConfigureAwait(false); + }, selectedOptions); if (charSet is not null) @@ -102,41 +105,50 @@ protected override IEnumerable CreateAllFixe if (charSet is CharSet.None or CharSet.Ansi or CharSet.Auto) { yield return new ConvertToSourceGeneratedInteropFix( - (editor, ct) => - ConvertToLibraryImport( + async (solutionEditor, documentId, ct) => + { + var editor = await solutionEditor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false); + await ConvertToLibraryImport( editor, node, warnForAdditionalWork, 'A', - ct), + ct).ConfigureAwait(false); + }, selectedOptions.Add(SelectedSuffixOption, new Option.String("A"))); } if (charSet is CharSet.Unicode or CharSet.Auto) { yield return new ConvertToSourceGeneratedInteropFix( - (editor, ct) => - ConvertToLibraryImport( + async (solutionEditor, documentId, ct) => + { + var editor = await solutionEditor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false); + await ConvertToLibraryImport( editor, node, warnForAdditionalWork, 'W', - ct), + ct).ConfigureAwait(false); + }, selectedOptions.Add(SelectedSuffixOption, new Option.String("W"))); } } } - protected override Func CreateFixForSelectedOptions(SyntaxNode node, ImmutableDictionary selectedOptions) + protected override Func CreateFixForSelectedOptions(SyntaxNode node, ImmutableDictionary selectedOptions) { bool warnForAdditionalWork = selectedOptions.TryGetValue(Option.MayRequireAdditionalWork, out Option mayRequireAdditionalWork) && mayRequireAdditionalWork is Option.Bool(true); char? suffix = selectedOptions.TryGetValue(SelectedSuffixOption, out Option selectedSuffixOption) && selectedSuffixOption is Option.String(string selectedSuffix) ? selectedSuffix[0] : null; - return (editor, ct) => - ConvertToLibraryImport( + return async (solutionEditor, documentId, ct) => + { + var editor = await solutionEditor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false); + await ConvertToLibraryImport( editor, node, warnForAdditionalWork, suffix, - ct); + ct).ConfigureAwait(false); + }; } private static string AppendSuffix(string entryPoint, char? entryPointSuffix) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddGeneratedComClassTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddGeneratedComClassTests.cs index 5ec79561e05dff..03b42304368e48 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddGeneratedComClassTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/AddGeneratedComClassTests.cs @@ -121,6 +121,190 @@ partial class C : J await VerifyCS.VerifyCodeFixAsync(source, fixedSource); } + [Fact] + public async Task TypeWithComVisibleTrue_RemovesComVisibleAttribute() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0B7171CD-04A3-41B6-AD10-FE86D52197DD")] + public partial interface I + { + } + + [ComVisible(true)] + class [|C|] : I + { + } + """; + + string fixedSource = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0B7171CD-04A3-41B6-AD10-FE86D52197DD")] + public partial interface I + { + } + + [GeneratedComClass] + partial class C : I + { + } + """; + + await VerifyCS.VerifyCodeFixAsync(source, fixedSource); + } + + [Fact] + public async Task TypeWithComVisibleFalse_PreservesComVisibleAttribute() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0B7171CD-04A3-41B6-AD10-FE86D52197DD")] + public partial interface I + { + } + + [ComVisible(false)] + class [|C|] : I + { + } + """; + + string fixedSource = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0B7171CD-04A3-41B6-AD10-FE86D52197DD")] + public partial interface I + { + } + + [ComVisible(false)] + [GeneratedComClass] + partial class C : I + { + } + """; + + await VerifyCS.VerifyCodeFixAsync(source, fixedSource); + } + + [Fact] + public async Task TypeWithComVisibleTrueOnSecondPartialDeclaration_RemovesComVisibleAttribute() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0B7171CD-04A3-41B6-AD10-FE86D52197DD")] + public partial interface I + { + } + + partial class [|C|] : I + { + } + + [ComVisible(true)] + partial class C + { + } + """; + + string fixedSource = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0B7171CD-04A3-41B6-AD10-FE86D52197DD")] + public partial interface I + { + } + + [GeneratedComClass] + partial class C : I + { + } + + partial class C + { + } + """; + + await VerifyCS.VerifyCodeFixAsync(source, fixedSource); + } + + [Fact] + public async Task TypeWithComVisibleTrueInSeparateFile_RemovesComVisibleAttribute() + { + string mainSource = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0B7171CD-04A3-41B6-AD10-FE86D52197DD")] + public partial interface I + { + } + + partial class [|C|] : I + { + } + """; + + string mainFixedSource = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0B7171CD-04A3-41B6-AD10-FE86D52197DD")] + public partial interface I + { + } + + [GeneratedComClass] + partial class C : I + { + } + """; + + string secondSource = """ + using System.Runtime.InteropServices; + + [ComVisible(true)] + partial class C + { + } + """; + + string secondFixedSource = """ + using System.Runtime.InteropServices; + + partial class C + { + } + """; + + var test = new VerifyCS.Test + { + TestCode = mainSource, + FixedCode = mainFixedSource, + }; + test.TestState.Sources.Add(("C.Second.cs", secondSource)); + test.FixedState.Sources.Add(("C.Second.cs", secondFixedSource)); + await test.RunAsync(); + } + [Fact] public async Task TypeThatInheritsFromGeneratedComClassType_ReportsDiagnostic() {