Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
Expand All @@ -14,7 +15,7 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;

using Microsoft.CodeAnalysis.FindSymbols;
using static Microsoft.Interop.Analyzers.AnalyzerDiagnostics;

namespace Microsoft.Interop.Analyzers
Expand Down Expand Up @@ -113,6 +114,11 @@ private async Task<Document> ConvertToGeneratedDllImport(
// Replace DllImport with GeneratedDllImport
SyntaxNode generatedDeclaration = generator.ReplaceNode(methodSyntax, dllImportSyntax, generatedDllImportSyntax);

if (!methodSymbol.MethodImplementationFlags.HasFlag(System.Reflection.MethodImplAttributes.PreserveSig))
{
generatedDeclaration = await RemoveNoPreserveSigTransform(editor, generatedDeclaration, methodSymbol, cancellationToken).ConfigureAwait(false);
}

if (unmanagedCallConvAttributeMaybe is not null)
{
generatedDeclaration = generator.AddAttributes(generatedDeclaration, unmanagedCallConvAttributeMaybe);
Expand All @@ -131,6 +137,175 @@ private async Task<Document> ConvertToGeneratedDllImport(
return editor.GetChangedDocument();
}

private async Task<SyntaxNode> RemoveNoPreserveSigTransform(DocumentEditor editor, SyntaxNode generatedDeclaration, IMethodSymbol methodSymbol, CancellationToken cancellationToken)
{
Document? document = editor.OriginalDocument;
IEnumerable<ReferencedSymbol>? referencedSymbols = await SymbolFinder.FindReferencesAsync(
methodSymbol, document.Project.Solution, cancellationToken).ConfigureAwait(false);

SyntaxNode root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

// Sometimes we can't validate that we've fixed all callers, so we warn the user that this fix might produce invalid code.
bool shouldWarn = false;

List<(SyntaxNode invocation, Func<SyntaxNode, SyntaxGenerator, SyntaxNode> action)> invocations = new();

foreach (ReferencedSymbol? referencedSymbol in referencedSymbols)
{
foreach (ReferenceLocation location in referencedSymbol.Locations)
{
if (!location.Document.Id.Equals(document.Id))
{
shouldWarn = true;
continue;
}
// We limited the search scope to the single document,
// so all reference should be in the same tree.
SyntaxNode? referenceNode = root.FindNode(location.Location.SourceSpan);
if (referenceNode is not IdentifierNameSyntax identifierNode)
{
// Unexpected scenario, skip and warn.
shouldWarn = true;
continue;
}

InvocationExpressionSyntax? invocation = identifierNode switch
{
{ Parent: InvocationExpressionSyntax invocationInScope } => invocationInScope,
{ Parent: MemberAccessExpressionSyntax { Parent: InvocationExpressionSyntax invocationOnType } } => invocationOnType,
_ => null!
};

if (invocation is null)
{
// We won't be able to fix non-invocation references,
// e.g. creating a delegate.
shouldWarn = true;
continue;
}

if (methodSymbol.ReturnsVoid)
{
// There is no return value, so we don't need to add any arguments to the invocation.
// We only need to wrap the invocation with a call to ThrowExceptionForHR
invocations.Add((invocation, WrapInvocationWithHRExceptionThrow));
}
else if (invocation.Parent.IsKind(SyntaxKind.ExpressionStatement))
{
// The return value isn't used, so discard the new out parameter value
invocations.Add((invocation,
(node, generator) =>
{
return WrapInvocationWithHRExceptionThrow(
((InvocationExpressionSyntax)node).AddArgumentListArguments(
SyntaxFactory.Argument(SyntaxFactory.IdentifierName(
SyntaxFactory.Identifier(
SyntaxFactory.TriviaList(),
SyntaxKind.UnderscoreToken,
"_",
"_",
SyntaxFactory.TriviaList())))
.WithRefKindKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword))),
generator);
}
));
}
else if (invocation.Parent.IsKind(SyntaxKind.EqualsValueClause))
{
LocalDeclarationStatementSyntax declaration = invocation.FirstAncestorOrSelf<LocalDeclarationStatementSyntax>();
if (declaration.IsKind(SyntaxKind.FieldDeclaration) || declaration.IsKind(SyntaxKind.EventFieldDeclaration))
{
// We can't fix initalizations without introducing or prepending to a static constructor
// for what is an unlikely scenario.
continue;
}
if (declaration.Declaration.Variables.Count != 1)
{
// We can't handle multiple variable initializations easily
continue;
}
// The result was used to initialize a variable,
// so initialize the variable inline
invocations.Add((declaration,
(node, generator) =>
{
var declaration = (LocalDeclarationStatementSyntax)node;
var invocation = (InvocationExpressionSyntax)declaration.Declaration.Variables[0].Initializer.Value;
return generator.ExpressionStatement(
WrapInvocationWithHRExceptionThrow(
invocation.AddArgumentListArguments(
SyntaxFactory.Argument(SyntaxFactory.DeclarationExpression(
declaration.Declaration.Type,
SyntaxFactory.SingleVariableDesignation(
declaration.Declaration.Variables[0].Identifier.WithoutTrivia())))
.WithRefKindKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword))),
generator));
}
));
}
else if (invocation.Parent.IsKind(SyntaxKind.SimpleAssignmentExpression) && invocation.Parent.Parent.IsKind(SyntaxKind.ExpressionStatement))
{
invocations.Add((invocation.Parent,
(node, generator) =>
{
var assignment = (AssignmentExpressionSyntax)node;
var invocation = (InvocationExpressionSyntax)assignment.Right;
return WrapInvocationWithHRExceptionThrow(
invocation.AddArgumentListArguments(
SyntaxFactory.Argument(generator.ClearTrivia(assignment.Left))
.WithRefKindKeyword(SyntaxFactory.Token(SyntaxKind.OutKeyword))),
generator);
}
));
}
else
{
shouldWarn = true;
}
}
}

foreach ((SyntaxNode node, Func<SyntaxNode, SyntaxGenerator, SyntaxNode> action) nodesWithReplaceAction in invocations)
{
editor.ReplaceNode(
nodesWithReplaceAction.node, (node, generator) =>
{
return nodesWithReplaceAction.action(node, generator);
});
}

SyntaxNode noPreserveSigDeclaration = editor.Generator.WithType(
generatedDeclaration,
editor.Generator.TypeExpression(editor.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Int32)));

if (!methodSymbol.ReturnsVoid)
{
noPreserveSigDeclaration = editor.Generator.AddParameters(
noPreserveSigDeclaration,
new[]
{
editor.Generator.ParameterDeclaration("@return", editor.Generator.GetType(generatedDeclaration), refKind: RefKind.Out)
});
}

if (shouldWarn)
{
noPreserveSigDeclaration = noPreserveSigDeclaration.WithAdditionalAnnotations(WarningAnnotation.Create(Resources.ConvertNoPreserveSigDllImportToGeneratedMayProduceInvalidCode));
}
return noPreserveSigDeclaration;

SyntaxNode WrapInvocationWithHRExceptionThrow(SyntaxNode node, SyntaxGenerator generator)
{
return generator.InvocationExpression(
generator.MemberAccessExpression(
generator.NameExpression(
editor.SemanticModel.Compilation.GetTypeByMetadataName(
TypeNames.System_Runtime_InteropServices_Marshal)),
"ThrowExceptionForHR"),
node);
}
}

private SyntaxNode GetGeneratedDllImportAttribute(
DocumentEditor editor,
SyntaxGenerator generator,
Expand Down Expand Up @@ -181,6 +356,11 @@ private SyntaxNode GetGeneratedDllImportAttribute(
argumentsToRemove.Add(argument);
}
}
else if (IsMatchingNamedArg(attrArg, nameof(DllImportAttribute.PreserveSig)))
{
// We transform the signature for PreserveSig, so we can remove the argument
argumentsToRemove.Add(argument);
}
}

generatedDllImportSyntax = generator.RemoveNodes(generatedDllImportSyntax, argumentsToRemove);
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@
<data name="ConstantAndElementCountInfoDisallowed" xml:space="preserve">
<value>Only one of 'ConstantElementCount' or 'ElementCountInfo' may be used in a 'MarshalUsingAttribute' for a given 'ElementIndirectionLevel'</value>
</data>
<data name="ConvertNoPreserveSigDllImportToGeneratedMayProduceInvalidCode" xml:space="preserve">
<value>Automatically converting a P/Invoke with 'PreserveSig' set to 'false' to a source-generated P/Invoke may produce invalid code</value>
</data>
<data name="ConvertToGeneratedDllImport" xml:space="preserve">
<value>Convert to 'GeneratedDllImport'</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ partial class Test

[DllImport(""DoesNotExist"", ThrowOnUnmappableChar = false)]
public static extern int [|Method2|](out int ret);

[DllImport(""DoesNotExist"", PreserveSig = true)]
public static extern int [|Method3|](out int ret);
}}";
// Fixed source will have CS8795 (Partial method must have an implementation) without generator run
string fixedSource = @$"
Expand All @@ -160,6 +163,9 @@ partial class Test

[GeneratedDllImport(""DoesNotExist"")]
public static partial int {{|CS8795:Method2|}}(out int ret);

[GeneratedDllImport(""DoesNotExist"")]
public static partial int {{|CS8795:Method3|}}(out int ret);
}}";
await VerifyCS.VerifyCodeFixAsync(
source,
Expand Down Expand Up @@ -239,5 +245,48 @@ await VerifyCS.VerifyCodeFixAsync(
source,
fixedSource);
}

[ConditionalFact]
public async Task PreserveSigFalseSignatureModified()
{
string source = @"
using System.Runtime.InteropServices;
partial class Test
{
[DllImport(""DoesNotExist"", PreserveSig = false)]
public static extern void [|VoidMethod|](bool param);
[DllImport(""DoesNotExist"", PreserveSig = false)]
public static extern long [|Method|](bool param);

public static void Code()
{
Test.VoidMethod(true);
Test.Method(true);
long value = Test.Method(true);
value = Test.Method(true);
}
}";
// Fixed source will have CS8795 (Partial method must have an implementation) without generator run
string fixedSource = @"
using System.Runtime.InteropServices;
partial class Test
{
[GeneratedDllImport(""DoesNotExist"")]
public static partial int {|CS8795:VoidMethod|}(bool param);
[GeneratedDllImport(""DoesNotExist"")]
public static partial int {|CS8795:Method|}(bool param, out long @return);

public static void Code()
{
Marshal.ThrowExceptionForHR(Test.VoidMethod(true));
Marshal.ThrowExceptionForHR(Test.Method(true, out _));
Marshal.ThrowExceptionForHR(Test.Method(true, out long value));
Marshal.ThrowExceptionForHR(Test.Method(true, out value));
}
}";
await VerifyCS.VerifyCodeFixAsync(
source,
fixedSource);
}
}
}