diff --git a/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs b/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
new file mode 100644
index 000000000..9b19fc954
--- /dev/null
+++ b/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
@@ -0,0 +1,103 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Immutable;
+using System.Composition;
+using System.Threading.Tasks;
+using CommunityToolkit.Mvvm.SourceGenerators;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CodeActions;
+using Microsoft.CodeAnalysis.CodeFixes;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Editing;
+using Microsoft.CodeAnalysis.Text;
+using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+
+namespace CommunityToolkit.Mvvm.CodeFixers;
+
+///
+/// A code fixer that automatically updates types using [ObservableObject] or [INotifyPropertyChanged]
+/// that have no base type to inherit from ObservableObject instead.
+///
+[ExportCodeFixProvider(LanguageNames.CSharp)]
+[Shared]
+public sealed class ClassUsingAttributeInsteadOfInheritanceCodeFixer : CodeFixProvider
+{
+ ///
+ public override ImmutableArray FixableDiagnosticIds { get; } = ImmutableArray.Create(
+ InheritFromObservableObjectInsteadOfUsingINotifyPropertyChangedAttributeId,
+ InheritFromObservableObjectInsteadOfUsingObservableObjectAttributeId);
+
+ ///
+ public override FixAllProvider? GetFixAllProvider()
+ {
+ return WellKnownFixAllProviders.BatchFixer;
+ }
+
+ ///
+ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
+ {
+ Diagnostic diagnostic = context.Diagnostics[0];
+ TextSpan diagnosticSpan = diagnostic.Location.SourceSpan;
+
+ // Retrieve the property passed by the analyzer
+ if (diagnostic.Properties[ClassUsingAttributeInsteadOfInheritanceAnalyzer.TypeNameKey] is not string typeName ||
+ diagnostic.Properties[ClassUsingAttributeInsteadOfInheritanceAnalyzer.AttributeTypeNameKey] is not string attributeTypeName)
+ {
+ return;
+ }
+
+ SyntaxNode? root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
+
+ // Get the class declaration from the target diagnostic
+ if (root!.FindNode(diagnosticSpan) is ClassDeclarationSyntax { Identifier.Text: string identifierName } classDeclaration &&
+ identifierName == typeName)
+ {
+ // Register the code fix to update the class declaration to inherit from ObservableObject instead
+ context.RegisterCodeFix(
+ CodeAction.Create(
+ title: "Inherit from ObservableObject",
+ createChangedDocument: token => UpdateReference(context.Document, root, classDeclaration, attributeTypeName),
+ equivalenceKey: "Inherit from ObservableObject"),
+ diagnostic);
+
+ return;
+ }
+ }
+
+ ///
+ /// Applies the code fix to a target class declaration and returns an updated document.
+ ///
+ /// The original document being fixed.
+ /// The original tree root belonging to the current document.
+ /// The to update.
+ /// The name of the attribute that should be removed.
+ /// An updated document with the applied code fix, and inheriting from ObservableObject.
+ private static Task UpdateReference(Document document, SyntaxNode root, ClassDeclarationSyntax classDeclaration, string attributeTypeName)
+ {
+ // Insert ObservableObject always in first position in the base list. The type might have
+ // some interfaces in the base list, so we just copy them back after ObservableObject.
+ SyntaxGenerator generator = SyntaxGenerator.GetGenerator(document);
+ ClassDeclarationSyntax updatedClassDeclaration = (ClassDeclarationSyntax)generator.AddBaseType(classDeclaration, IdentifierName("ObservableObject"));
+
+ // Find the attribute list and attribute to remove
+ foreach (AttributeListSyntax attributeList in updatedClassDeclaration.AttributeLists)
+ {
+ foreach (AttributeSyntax attribute in attributeList.Attributes)
+ {
+ if (attribute.Name is IdentifierNameSyntax { Identifier.Text: string identifierName } &&
+ (identifierName == attributeTypeName || (identifierName + "Attribute") == attributeTypeName))
+ {
+ // We found the attribute to remove and the list to update
+ updatedClassDeclaration = (ClassDeclarationSyntax)generator.RemoveNode(updatedClassDeclaration, attribute);
+
+ break;
+ }
+ }
+ }
+
+ return Task.FromResult(document.WithSyntaxRoot(root.ReplaceNode(classDeclaration, updatedClassDeclaration)));
+ }
+}
diff --git a/src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldFixer.cs b/src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldCodeFixer.cs
similarity index 96%
rename from src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldFixer.cs
rename to src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldCodeFixer.cs
index 5a3a186db..3788192b3 100644
--- a/src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldFixer.cs
+++ b/src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldCodeFixer.cs
@@ -7,13 +7,13 @@
using System.Threading;
using System.Threading.Tasks;
using CommunityToolkit.Mvvm.SourceGenerators;
-using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
+using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors;
namespace CommunityToolkit.Mvvm.CodeFixers;
@@ -25,7 +25,7 @@ namespace CommunityToolkit.Mvvm.CodeFixers;
public sealed class FieldReferenceForObservablePropertyFieldCodeFixer : CodeFixProvider
{
///
- public override ImmutableArray FixableDiagnosticIds { get; } = ImmutableArray.Create(DiagnosticDescriptors.FieldReferenceForObservablePropertyFieldId);
+ public override ImmutableArray FixableDiagnosticIds { get; } = ImmutableArray.Create(FieldReferenceForObservablePropertyFieldId);
///
public override FixAllProvider? GetFixAllProvider()
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/Analyzers/ClassUsingAttributeInsteadOfInheritanceAnalyzer.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/Analyzers/ClassUsingAttributeInsteadOfInheritanceAnalyzer.cs
index 6706d9cca..1dbacc7b9 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/Analyzers/ClassUsingAttributeInsteadOfInheritanceAnalyzer.cs
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/Analyzers/ClassUsingAttributeInsteadOfInheritanceAnalyzer.cs
@@ -17,6 +17,16 @@ namespace CommunityToolkit.Mvvm.SourceGenerators;
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public sealed class ClassUsingAttributeInsteadOfInheritanceAnalyzer : DiagnosticAnalyzer
{
+ ///
+ /// The key for the name of the target type to update.
+ ///
+ internal const string TypeNameKey = "TypeName";
+
+ ///
+ /// The key for the name of the attribute that was found and should be removed.
+ ///
+ internal const string AttributeTypeNameKey = "AttributeTypeName";
+
///
/// The mapping of target attributes that will trigger the analyzer.
///
@@ -67,7 +77,13 @@ public override void Initialize(AnalysisContext context)
if (classSymbol.BaseType is { SpecialType: SpecialType.System_Object })
{
// This type is using the attribute when it could just inherit from ObservableObject, which is preferred
- context.ReportDiagnostic(Diagnostic.Create(GeneratorAttributeNamesToDiagnosticsMap[attributeClass.Name], context.Symbol.Locations.FirstOrDefault(), context.Symbol));
+ context.ReportDiagnostic(Diagnostic.Create(
+ GeneratorAttributeNamesToDiagnosticsMap[attributeClass.Name],
+ context.Symbol.Locations.FirstOrDefault(),
+ ImmutableDictionary.Create()
+ .Add(TypeNameKey, classSymbol.Name)
+ .Add(AttributeTypeNameKey, attributeName),
+ context.Symbol));
}
}
}
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs
index 05efb3e48..993f959c1 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs
@@ -14,6 +14,16 @@ namespace CommunityToolkit.Mvvm.SourceGenerators.Diagnostics;
///
internal static class DiagnosticDescriptors
{
+ ///
+ /// The diagnostic id for .
+ ///
+ public const string InheritFromObservableObjectInsteadOfUsingINotifyPropertyChangedAttributeId = "MVVMTK0032";
+
+ ///
+ /// The diagnostic id for .
+ ///
+ public const string InheritFromObservableObjectInsteadOfUsingObservableObjectAttributeId = "MVVMTK0033";
+
///
/// The diagnostic id for .
///
@@ -519,7 +529,7 @@ internal static class DiagnosticDescriptors
///
///
public static readonly DiagnosticDescriptor InheritFromObservableObjectInsteadOfUsingINotifyPropertyChangedAttributeWarning = new DiagnosticDescriptor(
- id: "MVVMTK0032",
+ id: InheritFromObservableObjectInsteadOfUsingINotifyPropertyChangedAttributeId,
title: "Inherit from ObservableObject instead of using [INotifyPropertyChanged]",
messageFormat: "The type {0} is using the [INotifyPropertyChanged] attribute while having no base type, and it should instead inherit from ObservableObject",
category: typeof(INotifyPropertyChangedGenerator).FullName,
@@ -537,7 +547,7 @@ internal static class DiagnosticDescriptors
///
///
public static readonly DiagnosticDescriptor InheritFromObservableObjectInsteadOfUsingObservableObjectAttributeWarning = new DiagnosticDescriptor(
- id: "MVVMTK0033",
+ id: InheritFromObservableObjectInsteadOfUsingObservableObjectAttributeId,
title: "Inherit from ObservableObject instead of using [ObservableObject]",
messageFormat: "The type {0} is using the [ObservableObject] attribute while having no base type, and it should instead inherit from ObservableObject",
category: typeof(ObservableObjectGenerator).FullName,
diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
new file mode 100644
index 000000000..17a4a2c09
--- /dev/null
+++ b/tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
@@ -0,0 +1,327 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Threading.Tasks;
+using CommunityToolkit.Mvvm.ComponentModel;
+using Microsoft.CodeAnalysis.Testing;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using CSharpCodeFixTest = Microsoft.CodeAnalysis.CSharp.Testing.CSharpCodeFixTest<
+ CommunityToolkit.Mvvm.SourceGenerators.ClassUsingAttributeInsteadOfInheritanceAnalyzer,
+ CommunityToolkit.Mvvm.CodeFixers.ClassUsingAttributeInsteadOfInheritanceCodeFixer,
+ Microsoft.CodeAnalysis.Testing.Verifiers.MSTestVerifier>;
+using CSharpCodeFixVerifier = Microsoft.CodeAnalysis.CSharp.Testing.CSharpCodeFixVerifier<
+ CommunityToolkit.Mvvm.SourceGenerators.ClassUsingAttributeInsteadOfInheritanceAnalyzer,
+ CommunityToolkit.Mvvm.CodeFixers.ClassUsingAttributeInsteadOfInheritanceCodeFixer,
+ Microsoft.CodeAnalysis.Testing.Verifiers.MSTestVerifier>;
+
+namespace CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests;
+
+[TestClass]
+public class ClassUsingAttributeInsteadOfInheritanceCodeFixer
+{
+ [TestMethod]
+ [DataRow("INotifyPropertyChanged", "MVVMTK0032")]
+ [DataRow("ObservableObject", "MVVMTK0033")]
+ public async Task SingleAttributeList(string attributeTypeName, string diagnosticId)
+ {
+ string original = $$"""
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [{{attributeTypeName}}]
+ class C
+ {
+ }
+ """;
+
+ string @fixed = """
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ class C : ObservableObject
+ {
+ }
+ """;
+
+ CSharpCodeFixTest test = new()
+ {
+ TestCode = original,
+ FixedCode = @fixed,
+ ReferenceAssemblies = ReferenceAssemblies.Net.Net60
+ };
+
+ test.TestState.AdditionalReferences.Add(typeof(ObservableObject).Assembly);
+ test.ExpectedDiagnostics.AddRange(new[]
+ {
+ // /0/Test0.cs(5,15): warning : The type C is using the attribute while having no base type, and it should instead inherit from ObservableObject
+ CSharpCodeFixVerifier.Diagnostic(diagnosticId).WithSpan(5, 7, 5, 8).WithArguments("C")
+ });
+
+ await test.RunAsync();
+ }
+
+ [TestMethod]
+ [DataRow("INotifyPropertyChanged", "MVVMTK0032")]
+ [DataRow("ObservableObject", "MVVMTK0033")]
+ public async Task SingleAttributeList_WithOtherInterface(string attributeTypeName, string diagnosticId)
+ {
+ string original = $$"""
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [{{attributeTypeName}}]
+ class C : IDisposable
+ {
+ public void Dispose()
+ {
+ }
+ }
+ """;
+
+ string @fixed = """
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ class C : ObservableObject, IDisposable
+ {
+ public void Dispose()
+ {
+ }
+ }
+ """;
+
+ CSharpCodeFixTest test = new()
+ {
+ TestCode = original,
+ FixedCode = @fixed,
+ ReferenceAssemblies = ReferenceAssemblies.Net.Net60
+ };
+
+ test.TestState.AdditionalReferences.Add(typeof(ObservableObject).Assembly);
+ test.ExpectedDiagnostics.AddRange(new[]
+ {
+ // /0/Test0.cs(5,15): warning : The type C is using the attribute while having no base type, and it should instead inherit from ObservableObject
+ CSharpCodeFixVerifier.Diagnostic(diagnosticId).WithSpan(6, 7, 6, 8).WithArguments("C")
+ });
+
+ await test.RunAsync();
+ }
+
+ [TestMethod]
+ [DataRow("INotifyPropertyChanged", "MVVMTK0032")]
+ [DataRow("ObservableObject", "MVVMTK0033")]
+ public async Task MultipleAttributeLists_OneBeforeTarget(string attributeTypeName, string diagnosticId)
+ {
+ string original = $$"""
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [Test]
+ [{{attributeTypeName}}]
+ class C
+ {
+ }
+
+ class TestAttribute : Attribute
+ {
+ }
+ """;
+
+ string @fixed = """
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [Test]
+ class C : ObservableObject
+ {
+ }
+
+ class TestAttribute : Attribute
+ {
+ }
+ """;
+
+ CSharpCodeFixTest test = new()
+ {
+ TestCode = original,
+ FixedCode = @fixed,
+ ReferenceAssemblies = ReferenceAssemblies.Net.Net60
+ };
+
+ test.TestState.AdditionalReferences.Add(typeof(ObservableObject).Assembly);
+ test.ExpectedDiagnostics.AddRange(new[]
+ {
+ // /0/Test0.cs(5,15): warning : The type C is using the attribute while having no base type, and it should instead inherit from ObservableObject
+ CSharpCodeFixVerifier.Diagnostic(diagnosticId).WithSpan(7, 7, 7, 8).WithArguments("C")
+ });
+
+ await test.RunAsync();
+ }
+
+ [TestMethod]
+ [DataRow("INotifyPropertyChanged", "MVVMTK0032")]
+ [DataRow("ObservableObject", "MVVMTK0033")]
+ public async Task MultipleAttributeLists_OneAfterTarget(string attributeTypeName, string diagnosticId)
+ {
+ string original = $$"""
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [{{attributeTypeName}}]
+ [Test]
+ class C
+ {
+ }
+
+ class TestAttribute : Attribute
+ {
+ }
+ """;
+
+ string @fixed = """
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [Test]
+ class C : ObservableObject
+ {
+ }
+
+ class TestAttribute : Attribute
+ {
+ }
+ """;
+
+ CSharpCodeFixTest test = new()
+ {
+ TestCode = original,
+ FixedCode = @fixed,
+ ReferenceAssemblies = ReferenceAssemblies.Net.Net60
+ };
+
+ test.TestState.AdditionalReferences.Add(typeof(ObservableObject).Assembly);
+ test.ExpectedDiagnostics.AddRange(new[]
+ {
+ // /0/Test0.cs(5,15): warning : The type C is using the attribute while having no base type, and it should instead inherit from ObservableObject
+ CSharpCodeFixVerifier.Diagnostic(diagnosticId).WithSpan(7, 7, 7, 8).WithArguments("C")
+ });
+
+ await test.RunAsync();
+ }
+
+ [TestMethod]
+ [DataRow("INotifyPropertyChanged", "MVVMTK0032")]
+ [DataRow("ObservableObject", "MVVMTK0033")]
+ public async Task MultipleAttributeLists_OneBeforeAndOneAfterTarget(string attributeTypeName, string diagnosticId)
+ {
+ string original = $$"""
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [Test]
+ [{{attributeTypeName}}]
+ [Test]
+ class C
+ {
+ }
+
+ [AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
+ class TestAttribute : Attribute
+ {
+ }
+ """;
+
+ string @fixed = """
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [Test]
+ [Test]
+ class C : ObservableObject
+ {
+ }
+
+ [AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
+ class TestAttribute : Attribute
+ {
+ }
+ """;
+
+ CSharpCodeFixTest test = new()
+ {
+ TestCode = original,
+ FixedCode = @fixed,
+ ReferenceAssemblies = ReferenceAssemblies.Net.Net60
+ };
+
+ test.TestState.AdditionalReferences.Add(typeof(ObservableObject).Assembly);
+ test.ExpectedDiagnostics.AddRange(new[]
+ {
+ // /0/Test0.cs(5,15): warning : The type C is using the attribute while having no base type, and it should instead inherit from ObservableObject
+ CSharpCodeFixVerifier.Diagnostic(diagnosticId).WithSpan(8, 7, 8, 8).WithArguments("C")
+ });
+
+ await test.RunAsync();
+ }
+
+ [TestMethod]
+ [DataRow("INotifyPropertyChanged", "MVVMTK0032")]
+ [DataRow("ObservableObject", "MVVMTK0033")]
+ public async Task MultipleAttributesInAttributeList(string attributeTypeName, string diagnosticId)
+ {
+ string original = $$"""
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [Test, {{attributeTypeName}}]
+ class C
+ {
+ }
+
+ class TestAttribute : Attribute
+ {
+ }
+ """;
+
+ string @fixed = """
+ using System;
+ using CommunityToolkit.Mvvm.ComponentModel;
+
+ // This is some trivia
+ [Test]
+ class C : ObservableObject
+ {
+ }
+
+ class TestAttribute : Attribute
+ {
+ }
+ """;
+
+ CSharpCodeFixTest test = new()
+ {
+ TestCode = original,
+ FixedCode = @fixed,
+ ReferenceAssemblies = ReferenceAssemblies.Net.Net60
+ };
+
+ test.TestState.AdditionalReferences.Add(typeof(ObservableObject).Assembly);
+ test.ExpectedDiagnostics.AddRange(new[]
+ {
+ // /0/Test0.cs(5,15): warning : The type C is using the attribute while having no base type, and it should instead inherit from ObservableObject
+ CSharpCodeFixVerifier.Diagnostic(diagnosticId).WithSpan(6, 7, 6, 8).WithArguments("C")
+ });
+
+ await test.RunAsync();
+ }
+}
diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_FieldReferenceForObservablePropertyFieldFixer.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_FieldReferenceForObservablePropertyFieldCodeFixer.cs
similarity index 100%
rename from tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_FieldReferenceForObservablePropertyFieldFixer.cs
rename to tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_FieldReferenceForObservablePropertyFieldCodeFixer.cs