diff --git a/TUnit.Analyzers.CodeFixers/TimeoutCancellationTokenCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/TimeoutCancellationTokenCodeFixProvider.cs new file mode 100644 index 0000000000..69da900156 --- /dev/null +++ b/TUnit.Analyzers.CodeFixers/TimeoutCancellationTokenCodeFixProvider.cs @@ -0,0 +1,190 @@ +using System.Collections.Immutable; +using System.Composition; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Formatting; + +namespace TUnit.Analyzers.CodeFixers; + +[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(TimeoutCancellationTokenCodeFixProvider)), Shared] +public class TimeoutCancellationTokenCodeFixProvider : CodeFixProvider +{ + private const string SystemThreadingNamespace = "System.Threading"; + private const string CancellationTokenTypeName = "CancellationToken"; + private const string ParameterName = "cancellationToken"; + + public sealed override ImmutableArray FixableDiagnosticIds { get; } = + ImmutableArray.Create(Rules.MissingTimeoutCancellationTokenAttributes.Id); + + public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer; + + public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context) + { + var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false); + if (root is null) + { + return; + } + + foreach (var diagnostic in context.Diagnostics) + { + var node = root.FindNode(diagnostic.Location.SourceSpan); + var method = node.FirstAncestorOrSelf(); + if (method is null) + { + continue; + } + + context.RegisterCodeFix( + CodeAction.Create( + title: "Add CancellationToken parameter", + createChangedDocument: c => AddCancellationTokenAsync(context.Document, method, BodyMode.None, c), + equivalenceKey: "AddCancellationToken"), + diagnostic); + + // Body-modifying actions only make sense when there's a block body to prepend to. + // For expression-bodied methods we'd silently no-op, which is worse than not offering. + if (method.Body is not null) + { + context.RegisterCodeFix( + CodeAction.Create( + title: "Add CancellationToken parameter with ThrowIfCancellationRequested", + createChangedDocument: c => AddCancellationTokenAsync(context.Document, method, BodyMode.ThrowIfCancellationRequested, c), + equivalenceKey: "AddCancellationTokenWithThrow"), + diagnostic); + + context.RegisterCodeFix( + CodeAction.Create( + title: "Add CancellationToken parameter as discard", + createChangedDocument: c => AddCancellationTokenAsync(context.Document, method, BodyMode.Discard, c), + equivalenceKey: "AddCancellationTokenAsDiscard"), + diagnostic); + } + } + } + + private enum BodyMode + { + None, + ThrowIfCancellationRequested, + Discard, + } + + private static async Task AddCancellationTokenAsync( + Document document, + MethodDeclarationSyntax method, + BodyMode bodyMode, + CancellationToken cancellationToken) + { + var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false); + if (root is null) + { + return document; + } + + var parameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier(ParameterName)) + .WithType(SyntaxFactory.IdentifierName(CancellationTokenTypeName).WithTrailingTrivia(SyntaxFactory.Space)); + + MethodDeclarationSyntax updated = method + .WithParameterList(method.ParameterList.AddParameters(parameter)); + + if (bodyMode != BodyMode.None && updated.Body is { } body) + { + StatementSyntax statement = bodyMode == BodyMode.ThrowIfCancellationRequested + ? SyntaxFactory.ParseStatement($"{ParameterName}.ThrowIfCancellationRequested();") + : SyntaxFactory.ParseStatement($"_ = {ParameterName};"); + + statement = statement + .WithLeadingTrivia(SyntaxFactory.ElasticMarker) + .WithTrailingTrivia(SyntaxFactory.ElasticLineFeed); + + var newStatements = body.Statements.Insert(0, statement); + updated = updated.WithBody(body.WithStatements(newStatements)); + } + + updated = updated.WithAdditionalAnnotations(Formatter.Annotation); + + var newRoot = root.ReplaceNode(method, updated); + + if (newRoot is CompilationUnitSyntax compilationUnit + && !await IsSystemThreadingInScopeAsync(document, method, cancellationToken).ConfigureAwait(false)) + { + newRoot = EnsureSystemThreadingUsing(compilationUnit); + } + + var newDocument = document.WithSyntaxRoot(newRoot); + return await Formatter.FormatAsync(newDocument, Formatter.Annotation, cancellationToken: cancellationToken).ConfigureAwait(false); + } + + // Asks the semantic model whether `CancellationToken` already resolves to + // System.Threading.CancellationToken at the method's position. That covers file-level usings, + // same-file global usings, cross-file global usings, and SDK ImplicitUsings — so we skip + // adding a redundant `using System.Threading;` in every case where the compiler already sees it. + private static async Task IsSystemThreadingInScopeAsync( + Document document, + MethodDeclarationSyntax method, + CancellationToken cancellationToken) + { + var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false); + if (semanticModel is null) + { + return false; + } + + var symbols = semanticModel.LookupSymbols(method.SpanStart, name: CancellationTokenTypeName); + foreach (var symbol in symbols) + { + if (symbol is INamedTypeSymbol type + && type.ContainingNamespace?.ToDisplayString() == SystemThreadingNamespace) + { + return true; + } + } + + return false; + } + + private static CompilationUnitSyntax EnsureSystemThreadingUsing(CompilationUnitSyntax compilationUnit) + { + foreach (var usingDirective in compilationUnit.Usings) + { + if (usingDirective.Name?.ToString() == SystemThreadingNamespace) + { + return compilationUnit; + } + } + + var newUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName(SystemThreadingNamespace)) + .WithAdditionalAnnotations(Formatter.Annotation); + + // Insert in sorted position within the System.* group, or append if no System.* group exists. + // Leaves non-System usings undisturbed to respect the file's existing organization. + var insertAt = -1; + for (var i = 0; i < compilationUnit.Usings.Count; i++) + { + var existing = compilationUnit.Usings[i].Name?.ToString(); + if (existing is null || !existing.StartsWith("System", StringComparison.Ordinal)) + { + continue; + } + + if (string.CompareOrdinal(existing, SystemThreadingNamespace) > 0) + { + insertAt = i; + break; + } + + insertAt = i + 1; + } + + if (insertAt == -1) + { + return compilationUnit.AddUsings(newUsing); + } + + return compilationUnit.WithUsings(compilationUnit.Usings.Insert(insertAt, newUsing)); + } +} diff --git a/TUnit.Analyzers.Tests/TimeoutCancellationTokenCodeFixProviderTests.cs b/TUnit.Analyzers.Tests/TimeoutCancellationTokenCodeFixProviderTests.cs new file mode 100644 index 0000000000..b1b7024a4d --- /dev/null +++ b/TUnit.Analyzers.Tests/TimeoutCancellationTokenCodeFixProviderTests.cs @@ -0,0 +1,330 @@ +using Verifier = TUnit.Analyzers.Tests.Verifiers.CSharpCodeFixVerifier< + TUnit.Analyzers.TimeoutCancellationTokenAnalyzer, + TUnit.Analyzers.CodeFixers.TimeoutCancellationTokenCodeFixProvider>; + +namespace TUnit.Analyzers.Tests; + +public class TimeoutCancellationTokenCodeFixProviderTests +{ + [Test] + public async Task Adds_CancellationToken_Parameter() + { + await Verifier.VerifyCodeFixAsync( + """ + using TUnit.Core; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task {|#0:MyTest|}() + { + await Task.Yield(); + } + } + """, + Verifier.Diagnostic(Rules.MissingTimeoutCancellationTokenAttributes).WithLocation(0), + """ + using TUnit.Core; + using System.Threading; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task MyTest(CancellationToken cancellationToken) + { + await Task.Yield(); + } + } + """, + test => test.CodeActionEquivalenceKey = "AddCancellationToken"); + } + + [Test] + public async Task Adds_CancellationToken_Parameter_With_ThrowIfCancellationRequested() + { + await Verifier.VerifyCodeFixAsync( + """ + using TUnit.Core; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task {|#0:MyTest|}() + { + await Task.Yield(); + } + } + """, + Verifier.Diagnostic(Rules.MissingTimeoutCancellationTokenAttributes).WithLocation(0), + """ + using TUnit.Core; + using System.Threading; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task MyTest(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + await Task.Yield(); + } + } + """, + test => test.CodeActionEquivalenceKey = "AddCancellationTokenWithThrow"); + } + + [Test] + public async Task Adds_CancellationToken_Parameter_As_Discard() + { + await Verifier.VerifyCodeFixAsync( + """ + using TUnit.Core; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task {|#0:MyTest|}() + { + await Task.Yield(); + } + } + """, + Verifier.Diagnostic(Rules.MissingTimeoutCancellationTokenAttributes).WithLocation(0), + """ + using TUnit.Core; + using System.Threading; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task MyTest(CancellationToken cancellationToken) + { + _ = cancellationToken; + await Task.Yield(); + } + } + """, + test => test.CodeActionEquivalenceKey = "AddCancellationTokenAsDiscard"); + } + + [Test] + public async Task Does_Not_Duplicate_Existing_System_Threading_Using() + { + await Verifier.VerifyCodeFixAsync( + """ + using TUnit.Core; + using System.Threading; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task {|#0:MyTest|}() + { + await Task.Yield(); + } + } + """, + Verifier.Diagnostic(Rules.MissingTimeoutCancellationTokenAttributes).WithLocation(0), + """ + using TUnit.Core; + using System.Threading; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task MyTest(CancellationToken cancellationToken) + { + await Task.Yield(); + } + } + """, + test => test.CodeActionEquivalenceKey = "AddCancellationToken"); + } + + [Test] + public async Task Appends_CancellationToken_After_Existing_Parameters() + { + await Verifier.VerifyCodeFixAsync( + """ + using TUnit.Core; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + [Arguments(1, "hello")] + public async Task MyTest(int value, string {|#0:text|}) + { + await Task.Yield(); + } + } + """, + Verifier.Diagnostic(Rules.MissingTimeoutCancellationTokenAttributes).WithLocation(0), + """ + using TUnit.Core; + using System.Threading; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + [Arguments(1, "hello")] + public async Task MyTest(int value, string text, CancellationToken cancellationToken) + { + await Task.Yield(); + } + } + """, + test => test.CodeActionEquivalenceKey = "AddCancellationToken"); + } + + [Test] + public async Task Does_Not_Add_Using_When_System_Threading_Is_Global_Using_In_Other_File() + { + // A cross-file `global using System.Threading;` (e.g. auto-generated _GlobalUsings.g.cs from + // ImplicitUsings, or hand-rolled) already makes CancellationToken resolvable — adding + // `using System.Threading;` to the target file would be redundant. + const string globalUsings = "global using System.Threading;\n"; + + const string source = """ + using TUnit.Core; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task {|#0:MyTest|}() + { + await Task.Yield(); + } + } + """; + + const string fixedSource = """ + using TUnit.Core; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task MyTest(CancellationToken cancellationToken) + { + await Task.Yield(); + } + } + """; + + await Verifier.VerifyCodeFixAsync( + source, + Verifier.Diagnostic(Rules.MissingTimeoutCancellationTokenAttributes).WithLocation(0), + fixedSource, + test => + { + test.CodeActionEquivalenceKey = "AddCancellationToken"; + test.TestState.Sources.Add(globalUsings); + test.FixedState.Sources.Add(globalUsings); + }); + } + + [Test] + public async Task Inserts_System_Threading_Grouped_With_Adjacent_System_Using_When_Non_System_Interspersed() + { + // Locks in behaviour for the contrived case where non-System usings are interspersed between + // System.* usings: the new System.Threading sorts into the System group such that its nearer + // neighbour is also System.* (System.Threading.Tasks on the right), rather than being appended + // at the end of the file or landing with both neighbours non-System. + await Verifier.VerifyCodeFixAsync( + """ + using System.Text; + using Xunit; + using TUnit.Core; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task {|#0:MyTest|}() + { + await Task.Yield(); + } + } + + namespace Xunit { internal class Dummy {} } + """, + Verifier.Diagnostic(Rules.MissingTimeoutCancellationTokenAttributes).WithLocation(0), + """ + using System.Text; + using Xunit; + using TUnit.Core; + using System.Threading; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public async Task MyTest(CancellationToken cancellationToken) + { + await Task.Yield(); + } + } + + namespace Xunit { internal class Dummy {} } + """, + test => test.CodeActionEquivalenceKey = "AddCancellationToken"); + } + + [Test] + public async Task Expression_Bodied_Method_Only_Offers_Bare_Parameter_Action() + { + // Body-modifying variants would silently no-op on an expression-bodied method, + // so only the bare-parameter action is registered. + await Verifier.VerifyCodeFixAsync( + """ + using TUnit.Core; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public Task {|#0:MyTest|}() => Task.CompletedTask; + } + """, + Verifier.Diagnostic(Rules.MissingTimeoutCancellationTokenAttributes).WithLocation(0), + """ + using TUnit.Core; + using System.Threading; + using System.Threading.Tasks; + + public class MyClass + { + [Test] + [Timeout(1000)] + public Task MyTest(CancellationToken cancellationToken) => Task.CompletedTask; + } + """, + test => test.CodeActionEquivalenceKey = "AddCancellationToken"); + } +}