From 6a3159700fc2761892569e7e305cf8a344250f2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zvonimir=20Mati=C4=87?= Date: Tue, 11 Jun 2024 15:53:30 +0200 Subject: [PATCH] Optimized source generator by using ForAttributeWithMetadataName method and equality comparers for caching. --- ...ionSyntaxAndCompilationEqualityComparer.cs | 22 ++ ...MemberDeclarationSyntaxEqualityComparer.cs | 63 +++++ .../ProjectionExpressionGenerator.cs | 224 +++++++----------- 3 files changed, 173 insertions(+), 136 deletions(-) create mode 100644 src/EntityFrameworkCore.Projectables.Generator/MemberDeclarationSyntaxAndCompilationEqualityComparer.cs create mode 100644 src/EntityFrameworkCore.Projectables.Generator/MemberDeclarationSyntaxEqualityComparer.cs diff --git a/src/EntityFrameworkCore.Projectables.Generator/MemberDeclarationSyntaxAndCompilationEqualityComparer.cs b/src/EntityFrameworkCore.Projectables.Generator/MemberDeclarationSyntaxAndCompilationEqualityComparer.cs new file mode 100644 index 0000000..f835071 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/MemberDeclarationSyntaxAndCompilationEqualityComparer.cs @@ -0,0 +1,22 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace EntityFrameworkCore.Projectables.Generator; + +public class MemberDeclarationSyntaxAndCompilationEqualityComparer : IEqualityComparer<(MemberDeclarationSyntax, Compilation)> +{ + public bool Equals((MemberDeclarationSyntax, Compilation) x, (MemberDeclarationSyntax, Compilation) y) + { + return GetMemberDeclarationSyntaxAndCompilationName(x.Item1, x.Item2) == GetMemberDeclarationSyntaxAndCompilationName(y.Item1, y.Item2); + } + + public int GetHashCode((MemberDeclarationSyntax, Compilation) obj) + { + return GetMemberDeclarationSyntaxAndCompilationName(obj.Item1, obj.Item2).GetHashCode(); + } + + public static string GetMemberDeclarationSyntaxAndCompilationName(MemberDeclarationSyntax memberDeclarationSyntax, Compilation compilation) + { + return $"{compilation.AssemblyName}:{MemberDeclarationSyntaxEqualityComparer.GetMemberDeclarationSyntaxName(memberDeclarationSyntax)}"; + } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/MemberDeclarationSyntaxEqualityComparer.cs b/src/EntityFrameworkCore.Projectables.Generator/MemberDeclarationSyntaxEqualityComparer.cs new file mode 100644 index 0000000..f2c49a5 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/MemberDeclarationSyntaxEqualityComparer.cs @@ -0,0 +1,63 @@ +using System.Text; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace EntityFrameworkCore.Projectables.Generator; + +public class MemberDeclarationSyntaxEqualityComparer : IEqualityComparer +{ + public bool Equals(MemberDeclarationSyntax x, MemberDeclarationSyntax y) + { + return GetMemberDeclarationSyntaxName(x) == GetMemberDeclarationSyntaxName(y); + } + + public int GetHashCode(MemberDeclarationSyntax obj) + { + return GetMemberDeclarationSyntaxName(obj).GetHashCode(); + } + + public static string GetMemberDeclarationSyntaxName(MemberDeclarationSyntax memberDeclaration) + { + var sb = new StringBuilder(); + + // Get the member name + if (memberDeclaration is MethodDeclarationSyntax methodDeclaration) + { + sb.Append(methodDeclaration.Identifier.Text); + } + else if (memberDeclaration is PropertyDeclarationSyntax propertyDeclaration) + { + sb.Append(propertyDeclaration.Identifier.Text); + } + else if (memberDeclaration is FieldDeclarationSyntax fieldDeclaration) + { + sb.Append(string.Join(", ", fieldDeclaration.Declaration.Variables.Select(v => v.Identifier.Text))); + } + + // Traverse up the tree to get containing type names + var parent = memberDeclaration.Parent; + while (parent != null) + { + switch (parent) + { + case NamespaceDeclarationSyntax namespaceDeclaration: + sb.Insert(0, namespaceDeclaration.Name + "."); + break; + case ClassDeclarationSyntax classDeclaration: + sb.Insert(0, classDeclaration.Identifier.Text + "."); + break; + case StructDeclarationSyntax structDeclaration: + sb.Insert(0, structDeclaration.Identifier.Text + "."); + break; + case InterfaceDeclarationSyntax interfaceDeclaration: + sb.Insert(0, interfaceDeclaration.Identifier.Text + "."); + break; + case EnumDeclarationSyntax enumDeclaration: + sb.Insert(0, enumDeclaration.Identifier.Text + "."); + break; + } + parent = parent.Parent; + } + + return sb.ToString(); + } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 76ce700..6f177bc 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -3,15 +3,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Diagnostics; -using System.Linq; -using System.Security.Cryptography.X509Certificates; using System.Text; -using System.Threading; -using System.Threading.Tasks; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace EntityFrameworkCore.Projectables.Generator @@ -41,167 +33,127 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { // Do a simple filter for members IncrementalValuesProvider memberDeclarations = context.SyntaxProvider - .CreateSyntaxProvider( - predicate: static (s, _) => s is MemberDeclarationSyntax m && m.AttributeLists.Count > 0, - transform: static (c, _) => GetSemanticTargetForGeneration(c)) - .Where(static m => m is not null)!; // filter out attributed enums that we don't care about + .ForAttributeWithMetadataName( + ProjectablesAttributeName, + predicate: static (s, _) => s is MemberDeclarationSyntax, + transform: static (c, _) => (MemberDeclarationSyntax)c.TargetNode) + .WithComparer(new MemberDeclarationSyntaxEqualityComparer()); // Combine the selected enums with the `Compilation` - IncrementalValueProvider<(Compilation, ImmutableArray)> compilationAndEnums - = context.CompilationProvider.Combine(memberDeclarations.Collect()); + IncrementalValuesProvider<(MemberDeclarationSyntax, Compilation)> compilationAndMemberPairs = memberDeclarations + .Combine(context.CompilationProvider) + .WithComparer(new MemberDeclarationSyntaxAndCompilationEqualityComparer()); // Generate the source using the compilation and enums - context.RegisterImplementationSourceOutput(compilationAndEnums, + context.RegisterImplementationSourceOutput(compilationAndMemberPairs, static (spc, source) => Execute(source.Item1, source.Item2, spc)); } - static MemberDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context) + static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context) { - // we know the node is a MemberDeclarationSyntax - var memberDeclarationSyntax = (MemberDeclarationSyntax)context.Node; + var projectable = ProjectableInterpreter.GetDescriptor(compilation, member, context); - // loop through all the attributes on the method - foreach (var attributeListSyntax in memberDeclarationSyntax.AttributeLists) + if (projectable is null) { - foreach (var attributeSyntax in attributeListSyntax.Attributes) - { - if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is not IMethodSymbol attributeSymbol) - { - // weird, we couldn't get the symbol, ignore it - continue; - } - - var attributeContainingTypeSymbol = attributeSymbol.ContainingType; - var fullName = attributeContainingTypeSymbol.ToDisplayString(); - - // Is the attribute the [Projcetable] attribute? - if (fullName == ProjectablesAttributeName) - { - // return the enum - return memberDeclarationSyntax; - } - } - } - - // we didn't find the attribute we were looking for - return null; - } - - static void Execute(Compilation compilation, ImmutableArray members, SourceProductionContext context) - { - if (members.IsDefaultOrEmpty) - { - // nothing to do yet return; } - var projectables = members - .Select(x => ProjectableInterpreter.GetDescriptor(compilation, x, context)) - .Where(x => x is not null) - .Select(x => x!); - - var resultBuilder = new StringBuilder(); - - foreach (var projectable in projectables) + if (projectable.MemberName is null) { - if (projectable.MemberName is null) - { - throw new InvalidOperationException("Expected a memberName here"); - } + throw new InvalidOperationException("Expected a memberName here"); + } - var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName); - var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs"; + var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName); + var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs"; - var classSyntax = ClassDeclaration(generatedClassName) - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) - .WithTypeParameterList(projectable.ClassTypeParameterList) - .WithConstraintClauses(projectable.ClassConstraintClauses ?? List()) - .AddAttributeLists( - AttributeList() - .AddAttributes(_editorBrowsableAttribute) - ) - .AddMembers( - MethodDeclaration( - GenericName( - Identifier("global::System.Linq.Expressions.Expression"), - TypeArgumentList( - SingletonSeparatedList( - (TypeSyntax)GenericName( - Identifier("global::System.Func"), - GetLambdaTypeArgumentListSyntax(projectable) - ) + var classSyntax = ClassDeclaration(generatedClassName) + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) + .WithTypeParameterList(projectable.ClassTypeParameterList) + .WithConstraintClauses(projectable.ClassConstraintClauses ?? List()) + .AddAttributeLists( + AttributeList() + .AddAttributes(_editorBrowsableAttribute) + ) + .AddMembers( + MethodDeclaration( + GenericName( + Identifier("global::System.Linq.Expressions.Expression"), + TypeArgumentList( + SingletonSeparatedList( + (TypeSyntax)GenericName( + Identifier("global::System.Func"), + GetLambdaTypeArgumentListSyntax(projectable) ) ) - ), - "Expression" - ) - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) - .WithTypeParameterList(projectable.TypeParameterList) - .WithConstraintClauses(projectable.ConstraintClauses ?? List()) - .WithBody( - Block( - ReturnStatement( - ParenthesizedLambdaExpression( - projectable.ParametersList ?? ParameterList(), - null, - projectable.ExpressionBody - ) + ) + ), + "Expression" + ) + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) + .WithTypeParameterList(projectable.TypeParameterList) + .WithConstraintClauses(projectable.ConstraintClauses ?? List()) + .WithBody( + Block( + ReturnStatement( + ParenthesizedLambdaExpression( + projectable.ParametersList ?? ParameterList(), + null, + projectable.ExpressionBody ) ) - ) - ); + ) + ) + ); #nullable disable - var compilationUnit = CompilationUnit(); + var compilationUnit = CompilationUnit(); - foreach (var usingDirective in projectable.UsingDirectives) - { - compilationUnit = compilationUnit.AddUsings(usingDirective); - } + foreach (var usingDirective in projectable.UsingDirectives) + { + compilationUnit = compilationUnit.AddUsings(usingDirective); + } - if (projectable.ClassNamespace is not null) - { - compilationUnit = compilationUnit.AddUsings( - UsingDirective( - ParseName(projectable.ClassNamespace) - ) - ); - } + if (projectable.ClassNamespace is not null) + { + compilationUnit = compilationUnit.AddUsings( + UsingDirective( + ParseName(projectable.ClassNamespace) + ) + ); + } - compilationUnit = compilationUnit - .AddMembers( - NamespaceDeclaration( - ParseName("EntityFrameworkCore.Projectables.Generated") - ).AddMembers(classSyntax) + compilationUnit = compilationUnit + .AddMembers( + NamespaceDeclaration( + ParseName("EntityFrameworkCore.Projectables.Generated") + ).AddMembers(classSyntax) + ) + .WithLeadingTrivia( + TriviaList( + Comment("// "), + Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)) ) - .WithLeadingTrivia( - TriviaList( - Comment("// "), - Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)) - ) - ); + ); - context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8)); + context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8)); + static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable) + { + var lambdaTypeArguments = TypeArgumentList( + SeparatedList( + // TODO: Document where clause + projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!) + ) + ); - static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable) + if (projectable.ReturnTypeName is not null) { - var lambdaTypeArguments = TypeArgumentList( - SeparatedList( - // TODO: Document where clause - projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!) - ) - ); - - if (projectable.ReturnTypeName is not null) - { - lambdaTypeArguments = lambdaTypeArguments.AddArguments(ParseTypeName(projectable.ReturnTypeName)); - } - - return lambdaTypeArguments; + lambdaTypeArguments = lambdaTypeArguments.AddArguments(ParseTypeName(projectable.ReturnTypeName)); } + + return lambdaTypeArguments; } } }