diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems
index b019f6d49..879399168 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems
@@ -55,6 +55,7 @@
+
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/MethodDeclarationSyntaxExtensions.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/MethodDeclarationSyntaxExtensions.cs
new file mode 100644
index 000000000..9ce12e785
--- /dev/null
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/MethodDeclarationSyntaxExtensions.cs
@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+
+namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions;
+
+///
+/// Extension methods for the type.
+///
+internal static class MethodDeclarationSyntaxExtensions
+{
+ ///
+ /// Checks whether a given has or could potentially have any attribute lists.
+ ///
+ /// The input to check.
+ /// Whether has or potentially has any attribute lists.
+ public static bool HasOrPotentiallyHasAttributeLists(this MethodDeclarationSyntax methodDeclaration)
+ {
+ // If the declaration has any attribute lists, there's nothing left to do
+ if (methodDeclaration.AttributeLists.Count > 0)
+ {
+ return true;
+ }
+
+ // If there are no attributes, check whether the method declaration has the partial keyword. If it
+ // does, there could potentially be attribute lists on the other partial definition/implementation.
+ return methodDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword);
+ }
+}
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/SyntaxNodeExtensions.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/SyntaxNodeExtensions.cs
index 06fb95cbd..52b7ccbfc 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/SyntaxNodeExtensions.cs
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/SyntaxNodeExtensions.cs
@@ -26,8 +26,7 @@ internal static class SyntaxNodeExtensions
public static bool IsFirstSyntaxDeclarationForSymbol(this SyntaxNode syntaxNode, ISymbol symbol)
{
return
- symbol.DeclaringSyntaxReferences.Length > 0 &&
- symbol.DeclaringSyntaxReferences[0] is SyntaxReference syntaxReference &&
+ symbol.DeclaringSyntaxReferences is [SyntaxReference syntaxReference, ..] &&
syntaxReference.SyntaxTree == syntaxNode.SyntaxTree &&
syntaxReference.Span == syntaxNode.Span;
}
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs
index 33b2a752e..c3dd8cc64 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs
@@ -430,6 +430,15 @@ private static bool IsCommandDefinitionUnique(IMethodSymbol methodSymbol, in Imm
return true;
}
+ // If the two method symbols are partial and either is the implementation of the other one, this is allowed
+ if ((methodSymbol is { IsPartialDefinition: true, PartialImplementationPart: { } partialImplementation } &&
+ SymbolEqualityComparer.Default.Equals(otherSymbol, partialImplementation)) ||
+ (otherSymbol is { IsPartialDefinition: true, PartialImplementationPart: { } otherPartialImplementation } &&
+ SymbolEqualityComparer.Default.Equals(methodSymbol, otherPartialImplementation)))
+ {
+ continue;
+ }
+
diagnostics.Add(
MultipleRelayCommandMethodOverloadsError,
methodSymbol,
@@ -952,12 +961,24 @@ private static void GatherForwardedAttributes(
using ImmutableArrayBuilder fieldAttributesInfo = ImmutableArrayBuilder.Rent();
using ImmutableArrayBuilder propertyAttributesInfo = ImmutableArrayBuilder.Rent();
- foreach (SyntaxReference syntaxReference in methodSymbol.DeclaringSyntaxReferences)
+ static void GatherForwardedAttributes(
+ IMethodSymbol methodSymbol,
+ SemanticModel semanticModel,
+ CancellationToken token,
+ in ImmutableArrayBuilder diagnostics,
+ in ImmutableArrayBuilder fieldAttributesInfo,
+ in ImmutableArrayBuilder propertyAttributesInfo)
{
+ // Get the single syntax reference for the input method symbol (there should be only one)
+ if (methodSymbol.DeclaringSyntaxReferences is not [SyntaxReference syntaxReference])
+ {
+ return;
+ }
+
// Try to get the target method declaration syntax node
if (syntaxReference.GetSyntax(token) is not MethodDeclarationSyntax methodDeclaration)
{
- continue;
+ return;
}
// Gather explicit forwarded attributes info
@@ -998,6 +1019,22 @@ private static void GatherForwardedAttributes(
}
}
+ // If the method is a partial definition, also gather attributes from the implementation part
+ if (methodSymbol is { IsPartialDefinition: true } or { PartialDefinitionPart: not null })
+ {
+ IMethodSymbol partialDefinition = methodSymbol.PartialDefinitionPart ?? methodSymbol;
+ IMethodSymbol partialImplementation = methodSymbol.PartialImplementationPart ?? methodSymbol;
+
+ // We always give priority to the partial definition, to ensure a predictable and testable ordering
+ GatherForwardedAttributes(partialDefinition, semanticModel, token, in diagnostics, in fieldAttributesInfo, in propertyAttributesInfo);
+ GatherForwardedAttributes(partialImplementation, semanticModel, token, in diagnostics, in fieldAttributesInfo, in propertyAttributesInfo);
+ }
+ else
+ {
+ // If the method is not a partial definition/implementation, just gather attributes from the method with no modifications
+ GatherForwardedAttributes(methodSymbol, semanticModel, token, in diagnostics, in fieldAttributesInfo, in propertyAttributesInfo);
+ }
+
fieldAttributes = fieldAttributesInfo.ToImmutable();
propertyAttributes = propertyAttributesInfo.ToImmutable();
}
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs
index e18f83b9c..ddf24bf05 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs
@@ -27,7 +27,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
context.SyntaxProvider
.ForAttributeWithMetadataName(
"CommunityToolkit.Mvvm.Input.RelayCommandAttribute",
- static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax, AttributeLists.Count: > 0 },
+ static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax } methodDeclaration && methodDeclaration.HasOrPotentiallyHasAttributeLists(),
static (context, token) =>
{
if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8))
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Polyfills/SyntaxValueProviderExtensions.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Polyfills/SyntaxValueProviderExtensions.cs
index 44a949823..402a421d6 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/Polyfills/SyntaxValueProviderExtensions.cs
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Polyfills/SyntaxValueProviderExtensions.cs
@@ -59,6 +59,15 @@ public static IncrementalValuesProvider ForAttributeWithMetadataName(
return null;
}
+ // Edge case: if the symbol is a partial method, skip the implementation part and only process the partial method
+ // definition. This is needed because attributes will be reported as available on both the definition and the
+ // implementation part. To avoid generating duplicate files, we only give priority to the definition part.
+ // On Roslyn 4.3+, ForAttributeWithMetadataName will already only return the symbol the attribute was located on.
+ if (symbol is IMethodSymbol { IsPartialDefinition: false, PartialDefinitionPart: not null })
+ {
+ return null;
+ }
+
// Create the GeneratorAttributeSyntaxContext value to pass to the input transform. The attributes array
// will only ever have a single value, but that's fine with the attributes the various generators look for.
GeneratorAttributeSyntaxContext syntaxContext = new(
diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs
index fa76a7126..158eefbba 100644
--- a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs
+++ b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs
@@ -268,6 +268,169 @@ partial class MyViewModel
VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.MyViewModel.Test.g.cs", result));
}
+ // See https://github.com/CommunityToolkit/dotnet/issues/632
+ [TestMethod]
+ public void RelayCommandMethodWithPartialDeclarations_TriggersCorrectly()
+ {
+ string source = """
+ using CommunityToolkit.Mvvm.Input;
+
+ #nullable enable
+
+ namespace MyApp;
+
+ partial class MyViewModel
+ {
+ [RelayCommand]
+ private partial void Test1()
+ {
+ }
+
+ private partial void Test1();
+
+ private partial void Test2()
+ {
+ }
+
+ [RelayCommand]
+ private partial void Test2();
+ }
+ """;
+
+ string result1 = """
+ //
+ #pragma warning disable
+ #nullable enable
+ namespace MyApp
+ {
+ partial class MyViewModel
+ {
+ /// The backing field for .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
+ private global::CommunityToolkit.Mvvm.Input.RelayCommand? test1Command;
+ /// Gets an instance wrapping .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test1Command => test1Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test1));
+ }
+ }
+ """;
+
+ string result2 = """
+ //
+ #pragma warning disable
+ #nullable enable
+ namespace MyApp
+ {
+ partial class MyViewModel
+ {
+ /// The backing field for .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
+ private global::CommunityToolkit.Mvvm.Input.RelayCommand? test2Command;
+ /// Gets an instance wrapping .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test2Command => test2Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test2));
+ }
+ }
+ """;
+
+ VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.MyViewModel.Test1.g.cs", result1), ("MyApp.MyViewModel.Test2.g.cs", result2));
+ }
+
+ // See https://github.com/CommunityToolkit/dotnet/issues/632
+ [TestMethod]
+ public void RelayCommandMethodWithForwardedAttributesOverPartialDeclarations_MergesAttributes()
+ {
+ string source = """
+ using CommunityToolkit.Mvvm.Input;
+
+ #nullable enable
+
+ namespace MyApp;
+
+ partial class MyViewModel
+ {
+ [RelayCommand]
+ [field: Value(0)]
+ [property: Value(1)]
+ private partial void Test1()
+ {
+ }
+
+ [field: Value(2)]
+ [property: Value(3)]
+ private partial void Test1();
+
+ [field: Value(0)]
+ [property: Value(1)]
+ private partial void Test2()
+ {
+ }
+
+ [RelayCommand]
+ [field: Value(2)]
+ [property: Value(3)]
+ private partial void Test2();
+ }
+
+ public class ValueAttribute : Attribute
+ {
+ public ValueAttribute(object value)
+ {
+ }
+ }
+ """;
+
+ string result1 = """
+ //
+ #pragma warning disable
+ #nullable enable
+ namespace MyApp
+ {
+ partial class MyViewModel
+ {
+ /// The backing field for .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
+ [global::MyApp.ValueAttribute(2)]
+ [global::MyApp.ValueAttribute(0)]
+ private global::CommunityToolkit.Mvvm.Input.RelayCommand? test1Command;
+ /// Gets an instance wrapping .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::MyApp.ValueAttribute(3)]
+ [global::MyApp.ValueAttribute(1)]
+ public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test1Command => test1Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test1));
+ }
+ }
+ """;
+
+ string result2 = """
+ //
+ #pragma warning disable
+ #nullable enable
+ namespace MyApp
+ {
+ partial class MyViewModel
+ {
+ /// The backing field for .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
+ [global::MyApp.ValueAttribute(2)]
+ [global::MyApp.ValueAttribute(0)]
+ private global::CommunityToolkit.Mvvm.Input.RelayCommand? test2Command;
+ /// Gets an instance wrapping .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ [global::MyApp.ValueAttribute(3)]
+ [global::MyApp.ValueAttribute(1)]
+ public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test2Command => test2Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test2));
+ }
+ }
+ """;
+
+ VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.MyViewModel.Test1.g.cs", result1), ("MyApp.MyViewModel.Test2.g.cs", result2));
+ }
+
[TestMethod]
public void ObservablePropertyWithinGenericAndNestedTypes()
{
diff --git a/tests/CommunityToolkit.Mvvm.UnitTests/Test_RelayCommandAttribute.cs b/tests/CommunityToolkit.Mvvm.UnitTests/Test_RelayCommandAttribute.cs
index 3114cf1e5..ecf4ff05d 100644
--- a/tests/CommunityToolkit.Mvvm.UnitTests/Test_RelayCommandAttribute.cs
+++ b/tests/CommunityToolkit.Mvvm.UnitTests/Test_RelayCommandAttribute.cs
@@ -659,6 +659,42 @@ static void ValidateTestAttribute(TestValidationAttribute testAttribute)
Assert.AreEqual(testAttribute2.Animal, (Test_ObservablePropertyAttribute.Animal)67);
}
+ // See https://github.com/CommunityToolkit/dotnet/issues/632
+ [TestMethod]
+ public void Test_RelayCommandAttribute_WithPartialCommandMethodDefinitions()
+ {
+ ModelWithPartialCommandMethods model = new();
+
+ Assert.IsInstanceOfType(model.FooCommand);
+ Assert.IsInstanceOfType>(model.BarCommand);
+ Assert.IsInstanceOfType(model.BazCommand);
+ Assert.IsInstanceOfType(model.FooBarCommand);
+
+ FieldInfo bazField = typeof(ModelWithPartialCommandMethods).GetField("bazCommand", BindingFlags.Instance | BindingFlags.NonPublic)!;
+
+ Assert.IsNotNull(bazField.GetCustomAttribute());
+ Assert.IsNotNull(bazField.GetCustomAttribute());
+ Assert.AreEqual(bazField.GetCustomAttribute()!.Length, 1);
+
+ PropertyInfo bazProperty = typeof(ModelWithPartialCommandMethods).GetProperty("BazCommand")!;
+
+ Assert.IsNotNull(bazProperty.GetCustomAttribute());
+ Assert.AreEqual(bazProperty.GetCustomAttribute()!.Length, 2);
+ Assert.IsNotNull(bazProperty.GetCustomAttribute());
+
+ FieldInfo fooBarField = typeof(ModelWithPartialCommandMethods).GetField("fooBarCommand", BindingFlags.Instance | BindingFlags.NonPublic)!;
+
+ Assert.IsNotNull(fooBarField.GetCustomAttribute());
+ Assert.IsNotNull(fooBarField.GetCustomAttribute());
+ Assert.AreEqual(fooBarField.GetCustomAttribute()!.Length, 1);
+
+ PropertyInfo fooBarProperty = typeof(ModelWithPartialCommandMethods).GetProperty("FooBarCommand")!;
+
+ Assert.IsNotNull(fooBarProperty.GetCustomAttribute());
+ Assert.AreEqual(fooBarProperty.GetCustomAttribute()!.Length, 2);
+ Assert.IsNotNull(fooBarProperty.GetCustomAttribute());
+ }
+
#region Region
public class Region
{
@@ -1202,4 +1238,44 @@ public TestValidationAttribute(object? o, Type t, bool flag, double d, string[]
public Test_ObservablePropertyAttribute.Animal Animal { get; set; }
}
+
+ public partial class ModelWithPartialCommandMethods
+ {
+ [RelayCommand]
+ private partial void Foo();
+
+ private partial void Foo()
+ {
+ }
+
+ private partial void Bar(string name);
+
+ [RelayCommand]
+ private partial void Bar(string name)
+ {
+ }
+
+ [RelayCommand]
+ [field: Required]
+ [property: MinLength(2)]
+ partial void Baz();
+
+ [field: MinLength(1)]
+ [property: XmlIgnore]
+ partial void Baz()
+ {
+ }
+
+ [field: Required]
+ [property: MinLength(2)]
+ private partial Task FooBarAsync();
+
+ [RelayCommand]
+ [field: MinLength(1)]
+ [property: XmlIgnore]
+ private partial Task FooBarAsync()
+ {
+ return Task.CompletedTask;
+ }
+ }
}