diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems index b019f6d49..879399168 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems @@ -55,6 +55,7 @@ + diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/MethodDeclarationSyntaxExtensions.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/MethodDeclarationSyntaxExtensions.cs new file mode 100644 index 000000000..9ce12e785 --- /dev/null +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/MethodDeclarationSyntaxExtensions.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions; + +/// +/// Extension methods for the type. +/// +internal static class MethodDeclarationSyntaxExtensions +{ + /// + /// Checks whether a given has or could potentially have any attribute lists. + /// + /// The input to check. + /// Whether has or potentially has any attribute lists. + public static bool HasOrPotentiallyHasAttributeLists(this MethodDeclarationSyntax methodDeclaration) + { + // If the declaration has any attribute lists, there's nothing left to do + if (methodDeclaration.AttributeLists.Count > 0) + { + return true; + } + + // If there are no attributes, check whether the method declaration has the partial keyword. If it + // does, there could potentially be attribute lists on the other partial definition/implementation. + return methodDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword); + } +} diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/SyntaxNodeExtensions.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/SyntaxNodeExtensions.cs index 06fb95cbd..52b7ccbfc 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/SyntaxNodeExtensions.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/SyntaxNodeExtensions.cs @@ -26,8 +26,7 @@ internal static class SyntaxNodeExtensions public static bool IsFirstSyntaxDeclarationForSymbol(this SyntaxNode syntaxNode, ISymbol symbol) { return - symbol.DeclaringSyntaxReferences.Length > 0 && - symbol.DeclaringSyntaxReferences[0] is SyntaxReference syntaxReference && + symbol.DeclaringSyntaxReferences is [SyntaxReference syntaxReference, ..] && syntaxReference.SyntaxTree == syntaxNode.SyntaxTree && syntaxReference.Span == syntaxNode.Span; } diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs index 33b2a752e..c3dd8cc64 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs @@ -430,6 +430,15 @@ private static bool IsCommandDefinitionUnique(IMethodSymbol methodSymbol, in Imm return true; } + // If the two method symbols are partial and either is the implementation of the other one, this is allowed + if ((methodSymbol is { IsPartialDefinition: true, PartialImplementationPart: { } partialImplementation } && + SymbolEqualityComparer.Default.Equals(otherSymbol, partialImplementation)) || + (otherSymbol is { IsPartialDefinition: true, PartialImplementationPart: { } otherPartialImplementation } && + SymbolEqualityComparer.Default.Equals(methodSymbol, otherPartialImplementation))) + { + continue; + } + diagnostics.Add( MultipleRelayCommandMethodOverloadsError, methodSymbol, @@ -952,12 +961,24 @@ private static void GatherForwardedAttributes( using ImmutableArrayBuilder fieldAttributesInfo = ImmutableArrayBuilder.Rent(); using ImmutableArrayBuilder propertyAttributesInfo = ImmutableArrayBuilder.Rent(); - foreach (SyntaxReference syntaxReference in methodSymbol.DeclaringSyntaxReferences) + static void GatherForwardedAttributes( + IMethodSymbol methodSymbol, + SemanticModel semanticModel, + CancellationToken token, + in ImmutableArrayBuilder diagnostics, + in ImmutableArrayBuilder fieldAttributesInfo, + in ImmutableArrayBuilder propertyAttributesInfo) { + // Get the single syntax reference for the input method symbol (there should be only one) + if (methodSymbol.DeclaringSyntaxReferences is not [SyntaxReference syntaxReference]) + { + return; + } + // Try to get the target method declaration syntax node if (syntaxReference.GetSyntax(token) is not MethodDeclarationSyntax methodDeclaration) { - continue; + return; } // Gather explicit forwarded attributes info @@ -998,6 +1019,22 @@ private static void GatherForwardedAttributes( } } + // If the method is a partial definition, also gather attributes from the implementation part + if (methodSymbol is { IsPartialDefinition: true } or { PartialDefinitionPart: not null }) + { + IMethodSymbol partialDefinition = methodSymbol.PartialDefinitionPart ?? methodSymbol; + IMethodSymbol partialImplementation = methodSymbol.PartialImplementationPart ?? methodSymbol; + + // We always give priority to the partial definition, to ensure a predictable and testable ordering + GatherForwardedAttributes(partialDefinition, semanticModel, token, in diagnostics, in fieldAttributesInfo, in propertyAttributesInfo); + GatherForwardedAttributes(partialImplementation, semanticModel, token, in diagnostics, in fieldAttributesInfo, in propertyAttributesInfo); + } + else + { + // If the method is not a partial definition/implementation, just gather attributes from the method with no modifications + GatherForwardedAttributes(methodSymbol, semanticModel, token, in diagnostics, in fieldAttributesInfo, in propertyAttributesInfo); + } + fieldAttributes = fieldAttributesInfo.ToImmutable(); propertyAttributes = propertyAttributesInfo.ToImmutable(); } diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs index e18f83b9c..ddf24bf05 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs @@ -27,7 +27,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.SyntaxProvider .ForAttributeWithMetadataName( "CommunityToolkit.Mvvm.Input.RelayCommandAttribute", - static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax, AttributeLists.Count: > 0 }, + static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax } methodDeclaration && methodDeclaration.HasOrPotentiallyHasAttributeLists(), static (context, token) => { if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8)) diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Polyfills/SyntaxValueProviderExtensions.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Polyfills/SyntaxValueProviderExtensions.cs index 44a949823..402a421d6 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/Polyfills/SyntaxValueProviderExtensions.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Polyfills/SyntaxValueProviderExtensions.cs @@ -59,6 +59,15 @@ public static IncrementalValuesProvider ForAttributeWithMetadataName( return null; } + // Edge case: if the symbol is a partial method, skip the implementation part and only process the partial method + // definition. This is needed because attributes will be reported as available on both the definition and the + // implementation part. To avoid generating duplicate files, we only give priority to the definition part. + // On Roslyn 4.3+, ForAttributeWithMetadataName will already only return the symbol the attribute was located on. + if (symbol is IMethodSymbol { IsPartialDefinition: false, PartialDefinitionPart: not null }) + { + return null; + } + // Create the GeneratorAttributeSyntaxContext value to pass to the input transform. The attributes array // will only ever have a single value, but that's fine with the attributes the various generators look for. GeneratorAttributeSyntaxContext syntaxContext = new( diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs index fa76a7126..158eefbba 100644 --- a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs +++ b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs @@ -268,6 +268,169 @@ partial class MyViewModel VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.MyViewModel.Test.g.cs", result)); } + // See https://github.com/CommunityToolkit/dotnet/issues/632 + [TestMethod] + public void RelayCommandMethodWithPartialDeclarations_TriggersCorrectly() + { + string source = """ + using CommunityToolkit.Mvvm.Input; + + #nullable enable + + namespace MyApp; + + partial class MyViewModel + { + [RelayCommand] + private partial void Test1() + { + } + + private partial void Test1(); + + private partial void Test2() + { + } + + [RelayCommand] + private partial void Test2(); + } + """; + + string result1 = """ + // + #pragma warning disable + #nullable enable + namespace MyApp + { + partial class MyViewModel + { + /// The backing field for . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")] + private global::CommunityToolkit.Mvvm.Input.RelayCommand? test1Command; + /// Gets an instance wrapping . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test1Command => test1Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test1)); + } + } + """; + + string result2 = """ + // + #pragma warning disable + #nullable enable + namespace MyApp + { + partial class MyViewModel + { + /// The backing field for . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")] + private global::CommunityToolkit.Mvvm.Input.RelayCommand? test2Command; + /// Gets an instance wrapping . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test2Command => test2Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test2)); + } + } + """; + + VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.MyViewModel.Test1.g.cs", result1), ("MyApp.MyViewModel.Test2.g.cs", result2)); + } + + // See https://github.com/CommunityToolkit/dotnet/issues/632 + [TestMethod] + public void RelayCommandMethodWithForwardedAttributesOverPartialDeclarations_MergesAttributes() + { + string source = """ + using CommunityToolkit.Mvvm.Input; + + #nullable enable + + namespace MyApp; + + partial class MyViewModel + { + [RelayCommand] + [field: Value(0)] + [property: Value(1)] + private partial void Test1() + { + } + + [field: Value(2)] + [property: Value(3)] + private partial void Test1(); + + [field: Value(0)] + [property: Value(1)] + private partial void Test2() + { + } + + [RelayCommand] + [field: Value(2)] + [property: Value(3)] + private partial void Test2(); + } + + public class ValueAttribute : Attribute + { + public ValueAttribute(object value) + { + } + } + """; + + string result1 = """ + // + #pragma warning disable + #nullable enable + namespace MyApp + { + partial class MyViewModel + { + /// The backing field for . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")] + [global::MyApp.ValueAttribute(2)] + [global::MyApp.ValueAttribute(0)] + private global::CommunityToolkit.Mvvm.Input.RelayCommand? test1Command; + /// Gets an instance wrapping . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::MyApp.ValueAttribute(3)] + [global::MyApp.ValueAttribute(1)] + public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test1Command => test1Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test1)); + } + } + """; + + string result2 = """ + // + #pragma warning disable + #nullable enable + namespace MyApp + { + partial class MyViewModel + { + /// The backing field for . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")] + [global::MyApp.ValueAttribute(2)] + [global::MyApp.ValueAttribute(0)] + private global::CommunityToolkit.Mvvm.Input.RelayCommand? test2Command; + /// Gets an instance wrapping . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::MyApp.ValueAttribute(3)] + [global::MyApp.ValueAttribute(1)] + public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test2Command => test2Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test2)); + } + } + """; + + VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.MyViewModel.Test1.g.cs", result1), ("MyApp.MyViewModel.Test2.g.cs", result2)); + } + [TestMethod] public void ObservablePropertyWithinGenericAndNestedTypes() { diff --git a/tests/CommunityToolkit.Mvvm.UnitTests/Test_RelayCommandAttribute.cs b/tests/CommunityToolkit.Mvvm.UnitTests/Test_RelayCommandAttribute.cs index 3114cf1e5..ecf4ff05d 100644 --- a/tests/CommunityToolkit.Mvvm.UnitTests/Test_RelayCommandAttribute.cs +++ b/tests/CommunityToolkit.Mvvm.UnitTests/Test_RelayCommandAttribute.cs @@ -659,6 +659,42 @@ static void ValidateTestAttribute(TestValidationAttribute testAttribute) Assert.AreEqual(testAttribute2.Animal, (Test_ObservablePropertyAttribute.Animal)67); } + // See https://github.com/CommunityToolkit/dotnet/issues/632 + [TestMethod] + public void Test_RelayCommandAttribute_WithPartialCommandMethodDefinitions() + { + ModelWithPartialCommandMethods model = new(); + + Assert.IsInstanceOfType(model.FooCommand); + Assert.IsInstanceOfType>(model.BarCommand); + Assert.IsInstanceOfType(model.BazCommand); + Assert.IsInstanceOfType(model.FooBarCommand); + + FieldInfo bazField = typeof(ModelWithPartialCommandMethods).GetField("bazCommand", BindingFlags.Instance | BindingFlags.NonPublic)!; + + Assert.IsNotNull(bazField.GetCustomAttribute()); + Assert.IsNotNull(bazField.GetCustomAttribute()); + Assert.AreEqual(bazField.GetCustomAttribute()!.Length, 1); + + PropertyInfo bazProperty = typeof(ModelWithPartialCommandMethods).GetProperty("BazCommand")!; + + Assert.IsNotNull(bazProperty.GetCustomAttribute()); + Assert.AreEqual(bazProperty.GetCustomAttribute()!.Length, 2); + Assert.IsNotNull(bazProperty.GetCustomAttribute()); + + FieldInfo fooBarField = typeof(ModelWithPartialCommandMethods).GetField("fooBarCommand", BindingFlags.Instance | BindingFlags.NonPublic)!; + + Assert.IsNotNull(fooBarField.GetCustomAttribute()); + Assert.IsNotNull(fooBarField.GetCustomAttribute()); + Assert.AreEqual(fooBarField.GetCustomAttribute()!.Length, 1); + + PropertyInfo fooBarProperty = typeof(ModelWithPartialCommandMethods).GetProperty("FooBarCommand")!; + + Assert.IsNotNull(fooBarProperty.GetCustomAttribute()); + Assert.AreEqual(fooBarProperty.GetCustomAttribute()!.Length, 2); + Assert.IsNotNull(fooBarProperty.GetCustomAttribute()); + } + #region Region public class Region { @@ -1202,4 +1238,44 @@ public TestValidationAttribute(object? o, Type t, bool flag, double d, string[] public Test_ObservablePropertyAttribute.Animal Animal { get; set; } } + + public partial class ModelWithPartialCommandMethods + { + [RelayCommand] + private partial void Foo(); + + private partial void Foo() + { + } + + private partial void Bar(string name); + + [RelayCommand] + private partial void Bar(string name) + { + } + + [RelayCommand] + [field: Required] + [property: MinLength(2)] + partial void Baz(); + + [field: MinLength(1)] + [property: XmlIgnore] + partial void Baz() + { + } + + [field: Required] + [property: MinLength(2)] + private partial Task FooBarAsync(); + + [RelayCommand] + [field: MinLength(1)] + [property: XmlIgnore] + private partial Task FooBarAsync() + { + return Task.CompletedTask; + } + } }