diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Analyzers/ConvertToGeneratedDllImportFixer.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Analyzers/ConvertToGeneratedDllImportFixer.cs index 829d508a72646a..bf0c644d415f28 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Analyzers/ConvertToGeneratedDllImportFixer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Analyzers/ConvertToGeneratedDllImportFixer.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; @@ -14,7 +15,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Editing; - +using Microsoft.CodeAnalysis.FindSymbols; using static Microsoft.Interop.Analyzers.AnalyzerDiagnostics; namespace Microsoft.Interop.Analyzers @@ -113,6 +114,11 @@ private async Task ConvertToGeneratedDllImport( // Replace DllImport with GeneratedDllImport SyntaxNode generatedDeclaration = generator.ReplaceNode(methodSyntax, dllImportSyntax, generatedDllImportSyntax); + if (!methodSymbol.MethodImplementationFlags.HasFlag(System.Reflection.MethodImplAttributes.PreserveSig)) + { + generatedDeclaration = await RemoveNoPreserveSigTransform(editor, generatedDeclaration, methodSymbol, cancellationToken).ConfigureAwait(false); + } + if (unmanagedCallConvAttributeMaybe is not null) { generatedDeclaration = generator.AddAttributes(generatedDeclaration, unmanagedCallConvAttributeMaybe); @@ -131,6 +137,175 @@ private async Task ConvertToGeneratedDllImport( return editor.GetChangedDocument(); } + private async Task RemoveNoPreserveSigTransform(DocumentEditor editor, SyntaxNode generatedDeclaration, IMethodSymbol methodSymbol, CancellationToken cancellationToken) + { + Document? document = editor.OriginalDocument; + IEnumerable? referencedSymbols = await SymbolFinder.FindReferencesAsync( + methodSymbol, document.Project.Solution, cancellationToken).ConfigureAwait(false); + + SyntaxNode root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false); + + // Sometimes we can't validate that we've fixed all callers, so we warn the user that this fix might produce invalid code. + bool shouldWarn = false; + + List<(SyntaxNode invocation, Func action)> invocations = new(); + + foreach (ReferencedSymbol? referencedSymbol in referencedSymbols) + { + foreach (ReferenceLocation location in referencedSymbol.Locations) + { + if (!location.Document.Id.Equals(document.Id)) + { + shouldWarn = true; + continue; + } + // We limited the search scope to the single document, + // so all reference should be in the same tree. + SyntaxNode? referenceNode = root.FindNode(location.Location.SourceSpan); + if (referenceNode is not IdentifierNameSyntax identifierNode) + { + // Unexpected scenario, skip and warn. + shouldWarn = true; + continue; + } + + InvocationExpressionSyntax? invocation = identifierNode switch + { + { Parent: InvocationExpressionSyntax invocationInScope } => invocationInScope, + { Parent: MemberAccessExpressionSyntax { Parent: InvocationExpressionSyntax invocationOnType } } => invocationOnType, + _ => null! + }; + + if (invocation is null) + { + // We won't be able to fix non-invocation references, + // e.g. creating a delegate. + shouldWarn = true; + continue; + } + + if (methodSymbol.ReturnsVoid) + { + // There is no return value, so we don't need to add any arguments to the invocation. + // We only need to wrap the invocation with a call to ThrowExceptionForHR + invocations.Add((invocation, WrapInvocationWithHRExceptionThrow)); + } + else if (invocation.Parent.IsKind(SyntaxKind.ExpressionStatement)) + { + // The return value isn't used, so discard the new out parameter value + invocations.Add((invocation, + (node, generator) => + { + return WrapInvocationWithHRExceptionThrow( + ((InvocationExpressionSyntax)node).AddArgumentListArguments( + SyntaxFactory.Argument(SyntaxFactory.IdentifierName( + SyntaxFactory.Identifier( + SyntaxFactory.TriviaList(), + SyntaxKind.UnderscoreToken, + "_", + "_", + SyntaxFactory.TriviaList()))) + .WithRefKindKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword))), + generator); + } + )); + } + else if (invocation.Parent.IsKind(SyntaxKind.EqualsValueClause)) + { + LocalDeclarationStatementSyntax declaration = invocation.FirstAncestorOrSelf(); + if (declaration.IsKind(SyntaxKind.FieldDeclaration) || declaration.IsKind(SyntaxKind.EventFieldDeclaration)) + { + // We can't fix initalizations without introducing or prepending to a static constructor + // for what is an unlikely scenario. + continue; + } + if (declaration.Declaration.Variables.Count != 1) + { + // We can't handle multiple variable initializations easily + continue; + } + // The result was used to initialize a variable, + // so initialize the variable inline + invocations.Add((declaration, + (node, generator) => + { + var declaration = (LocalDeclarationStatementSyntax)node; + var invocation = (InvocationExpressionSyntax)declaration.Declaration.Variables[0].Initializer.Value; + return generator.ExpressionStatement( + WrapInvocationWithHRExceptionThrow( + invocation.AddArgumentListArguments( + SyntaxFactory.Argument(SyntaxFactory.DeclarationExpression( + declaration.Declaration.Type, + SyntaxFactory.SingleVariableDesignation( + declaration.Declaration.Variables[0].Identifier.WithoutTrivia()))) + .WithRefKindKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword))), + generator)); + } + )); + } + else if (invocation.Parent.IsKind(SyntaxKind.SimpleAssignmentExpression) && invocation.Parent.Parent.IsKind(SyntaxKind.ExpressionStatement)) + { + invocations.Add((invocation.Parent, + (node, generator) => + { + var assignment = (AssignmentExpressionSyntax)node; + var invocation = (InvocationExpressionSyntax)assignment.Right; + return WrapInvocationWithHRExceptionThrow( + invocation.AddArgumentListArguments( + SyntaxFactory.Argument(generator.ClearTrivia(assignment.Left)) + .WithRefKindKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword))), + generator); + } + )); + } + else + { + shouldWarn = true; + } + } + } + + foreach ((SyntaxNode node, Func action) nodesWithReplaceAction in invocations) + { + editor.ReplaceNode( + nodesWithReplaceAction.node, (node, generator) => + { + return nodesWithReplaceAction.action(node, generator); + }); + } + + SyntaxNode noPreserveSigDeclaration = editor.Generator.WithType( + generatedDeclaration, + editor.Generator.TypeExpression(editor.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Int32))); + + if (!methodSymbol.ReturnsVoid) + { + noPreserveSigDeclaration = editor.Generator.AddParameters( + noPreserveSigDeclaration, + new[] + { + editor.Generator.ParameterDeclaration("@return", editor.Generator.GetType(generatedDeclaration), refKind: RefKind.Out) + }); + } + + if (shouldWarn) + { + noPreserveSigDeclaration = noPreserveSigDeclaration.WithAdditionalAnnotations(WarningAnnotation.Create(Resources.ConvertNoPreserveSigDllImportToGeneratedMayProduceInvalidCode)); + } + return noPreserveSigDeclaration; + + SyntaxNode WrapInvocationWithHRExceptionThrow(SyntaxNode node, SyntaxGenerator generator) + { + return generator.InvocationExpression( + generator.MemberAccessExpression( + generator.NameExpression( + editor.SemanticModel.Compilation.GetTypeByMetadataName( + TypeNames.System_Runtime_InteropServices_Marshal)), + "ThrowExceptionForHR"), + node); + } + } + private SyntaxNode GetGeneratedDllImportAttribute( DocumentEditor editor, SyntaxGenerator generator, @@ -181,6 +356,11 @@ private SyntaxNode GetGeneratedDllImportAttribute( argumentsToRemove.Add(argument); } } + else if (IsMatchingNamedArg(attrArg, nameof(DllImportAttribute.PreserveSig))) + { + // We transform the signature for PreserveSig, so we can remove the argument + argumentsToRemove.Add(argument); + } } generatedDllImportSyntax = generator.RemoveNodes(generatedDllImportSyntax, argumentsToRemove); diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Resources.Designer.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Resources.Designer.cs index a8c33b93818ec0..097bc342b40648 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Resources.Designer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Resources.Designer.cs @@ -222,6 +222,15 @@ internal static string ConstantAndElementCountInfoDisallowed { } } + /// + /// Looks up a localized string similar to Automatically converting a P/Invoke with 'PreserveSig' set to 'false' to a source-generated P/Invoke may produce invalid code. + /// + internal static string ConvertNoPreserveSigDllImportToGeneratedMayProduceInvalidCode { + get { + return ResourceManager.GetString("ConvertNoPreserveSigDllImportToGeneratedMayProduceInvalidCode", resourceCulture); + } + } + /// /// Looks up a localized string similar to Convert to 'GeneratedDllImport'. /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Resources.resx b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Resources.resx index 8a9d0f5b7e12e7..dfb4176015d592 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Resources.resx +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Resources.resx @@ -171,6 +171,9 @@ Only one of 'ConstantElementCount' or 'ElementCountInfo' may be used in a 'MarshalUsingAttribute' for a given 'ElementIndirectionLevel' + + Automatically converting a P/Invoke with 'PreserveSig' set to 'false' to a source-generated P/Invoke may produce invalid code + Convert to 'GeneratedDllImport' diff --git a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/ConvertToGeneratedDllImportFixerTests.cs b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/ConvertToGeneratedDllImportFixerTests.cs index a818ba851eeb3c..2335db72848ea8 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/ConvertToGeneratedDllImportFixerTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/ConvertToGeneratedDllImportFixerTests.cs @@ -149,6 +149,9 @@ partial class Test [DllImport(""DoesNotExist"", ThrowOnUnmappableChar = false)] public static extern int [|Method2|](out int ret); + + [DllImport(""DoesNotExist"", PreserveSig = true)] + public static extern int [|Method3|](out int ret); }}"; // Fixed source will have CS8795 (Partial method must have an implementation) without generator run string fixedSource = @$" @@ -160,6 +163,9 @@ partial class Test [GeneratedDllImport(""DoesNotExist"")] public static partial int {{|CS8795:Method2|}}(out int ret); + + [GeneratedDllImport(""DoesNotExist"")] + public static partial int {{|CS8795:Method3|}}(out int ret); }}"; await VerifyCS.VerifyCodeFixAsync( source, @@ -239,5 +245,48 @@ await VerifyCS.VerifyCodeFixAsync( source, fixedSource); } + + [ConditionalFact] + public async Task PreserveSigFalseSignatureModified() + { + string source = @" +using System.Runtime.InteropServices; +partial class Test +{ + [DllImport(""DoesNotExist"", PreserveSig = false)] + public static extern void [|VoidMethod|](bool param); + [DllImport(""DoesNotExist"", PreserveSig = false)] + public static extern long [|Method|](bool param); + + public static void Code() + { + Test.VoidMethod(true); + Test.Method(true); + long value = Test.Method(true); + value = Test.Method(true); + } +}"; + // Fixed source will have CS8795 (Partial method must have an implementation) without generator run + string fixedSource = @" +using System.Runtime.InteropServices; +partial class Test +{ + [GeneratedDllImport(""DoesNotExist"")] + public static partial int {|CS8795:VoidMethod|}(bool param); + [GeneratedDllImport(""DoesNotExist"")] + public static partial int {|CS8795:Method|}(bool param, out long @return); + + public static void Code() + { + Marshal.ThrowExceptionForHR(Test.VoidMethod(true)); + Marshal.ThrowExceptionForHR(Test.Method(true, out _)); + Marshal.ThrowExceptionForHR(Test.Method(true, out long value)); + Marshal.ThrowExceptionForHR(Test.Method(true, out value)); + } +}"; + await VerifyCS.VerifyCodeFixAsync( + source, + fixedSource); + } } }