Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace EntityFrameworkCore.Projectables.Generator;

public class MemberDeclarationSyntaxAndCompilationEqualityComparer : IEqualityComparer<(MemberDeclarationSyntax, Compilation)>
{
public bool Equals((MemberDeclarationSyntax, Compilation) x, (MemberDeclarationSyntax, Compilation) y)
{
return GetMemberDeclarationSyntaxAndCompilationName(x.Item1, x.Item2) == GetMemberDeclarationSyntaxAndCompilationName(y.Item1, y.Item2);
}

public int GetHashCode((MemberDeclarationSyntax, Compilation) obj)
{
return GetMemberDeclarationSyntaxAndCompilationName(obj.Item1, obj.Item2).GetHashCode();
}

public static string GetMemberDeclarationSyntaxAndCompilationName(MemberDeclarationSyntax memberDeclarationSyntax, Compilation compilation)
{
return $"{compilation.AssemblyName}:{MemberDeclarationSyntaxEqualityComparer.GetMemberDeclarationSyntaxName(memberDeclarationSyntax)}";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using System.Text;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace EntityFrameworkCore.Projectables.Generator;

public class MemberDeclarationSyntaxEqualityComparer : IEqualityComparer<MemberDeclarationSyntax>
{
public bool Equals(MemberDeclarationSyntax x, MemberDeclarationSyntax y)
{
return GetMemberDeclarationSyntaxName(x) == GetMemberDeclarationSyntaxName(y);
}

public int GetHashCode(MemberDeclarationSyntax obj)
{
return GetMemberDeclarationSyntaxName(obj).GetHashCode();
}

public static string GetMemberDeclarationSyntaxName(MemberDeclarationSyntax memberDeclaration)
{
var sb = new StringBuilder();

// Get the member name
if (memberDeclaration is MethodDeclarationSyntax methodDeclaration)
{
sb.Append(methodDeclaration.Identifier.Text);
}
else if (memberDeclaration is PropertyDeclarationSyntax propertyDeclaration)
{
sb.Append(propertyDeclaration.Identifier.Text);
}
else if (memberDeclaration is FieldDeclarationSyntax fieldDeclaration)
{
sb.Append(string.Join(", ", fieldDeclaration.Declaration.Variables.Select(v => v.Identifier.Text)));
}

// Traverse up the tree to get containing type names
var parent = memberDeclaration.Parent;
while (parent != null)
{
switch (parent)
{
case NamespaceDeclarationSyntax namespaceDeclaration:
sb.Insert(0, namespaceDeclaration.Name + ".");
break;
case ClassDeclarationSyntax classDeclaration:
sb.Insert(0, classDeclaration.Identifier.Text + ".");
break;
case StructDeclarationSyntax structDeclaration:
sb.Insert(0, structDeclaration.Identifier.Text + ".");
break;
case InterfaceDeclarationSyntax interfaceDeclaration:
sb.Insert(0, interfaceDeclaration.Identifier.Text + ".");
break;
case EnumDeclarationSyntax enumDeclaration:
sb.Insert(0, enumDeclaration.Identifier.Text + ".");
break;
}
parent = parent.Parent;
}

return sb.ToString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace EntityFrameworkCore.Projectables.Generator
Expand Down Expand Up @@ -41,167 +33,127 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Do a simple filter for members
IncrementalValuesProvider<MemberDeclarationSyntax> memberDeclarations = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => s is MemberDeclarationSyntax m && m.AttributeLists.Count > 0,
transform: static (c, _) => GetSemanticTargetForGeneration(c))
.Where(static m => m is not null)!; // filter out attributed enums that we don't care about
.ForAttributeWithMetadataName(
ProjectablesAttributeName,
predicate: static (s, _) => s is MemberDeclarationSyntax,
transform: static (c, _) => (MemberDeclarationSyntax)c.TargetNode)
.WithComparer(new MemberDeclarationSyntaxEqualityComparer());

// Combine the selected enums with the `Compilation`
IncrementalValueProvider<(Compilation, ImmutableArray<MemberDeclarationSyntax>)> compilationAndEnums
= context.CompilationProvider.Combine(memberDeclarations.Collect());
IncrementalValuesProvider<(MemberDeclarationSyntax, Compilation)> compilationAndMemberPairs = memberDeclarations
.Combine(context.CompilationProvider)
.WithComparer(new MemberDeclarationSyntaxAndCompilationEqualityComparer());

// Generate the source using the compilation and enums
context.RegisterImplementationSourceOutput(compilationAndEnums,
context.RegisterImplementationSourceOutput(compilationAndMemberPairs,
static (spc, source) => Execute(source.Item1, source.Item2, spc));
}

static MemberDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context)
{
// we know the node is a MemberDeclarationSyntax
var memberDeclarationSyntax = (MemberDeclarationSyntax)context.Node;
var projectable = ProjectableInterpreter.GetDescriptor(compilation, member, context);

// loop through all the attributes on the method
foreach (var attributeListSyntax in memberDeclarationSyntax.AttributeLists)
if (projectable is null)
{
foreach (var attributeSyntax in attributeListSyntax.Attributes)
{
if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is not IMethodSymbol attributeSymbol)
{
// weird, we couldn't get the symbol, ignore it
continue;
}

var attributeContainingTypeSymbol = attributeSymbol.ContainingType;
var fullName = attributeContainingTypeSymbol.ToDisplayString();

// Is the attribute the [Projcetable] attribute?
if (fullName == ProjectablesAttributeName)
{
// return the enum
return memberDeclarationSyntax;
}
}
}

// we didn't find the attribute we were looking for
return null;
}

static void Execute(Compilation compilation, ImmutableArray<MemberDeclarationSyntax> members, SourceProductionContext context)
{
if (members.IsDefaultOrEmpty)
{
// nothing to do yet
return;
}

var projectables = members
.Select(x => ProjectableInterpreter.GetDescriptor(compilation, x, context))
.Where(x => x is not null)
.Select(x => x!);

var resultBuilder = new StringBuilder();

foreach (var projectable in projectables)
if (projectable.MemberName is null)
{
if (projectable.MemberName is null)
{
throw new InvalidOperationException("Expected a memberName here");
}
throw new InvalidOperationException("Expected a memberName here");
}

var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";
var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";

var classSyntax = ClassDeclaration(generatedClassName)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.ClassTypeParameterList)
.WithConstraintClauses(projectable.ClassConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.AddAttributeLists(
AttributeList()
.AddAttributes(_editorBrowsableAttribute)
)
.AddMembers(
MethodDeclaration(
GenericName(
Identifier("global::System.Linq.Expressions.Expression"),
TypeArgumentList(
SingletonSeparatedList(
(TypeSyntax)GenericName(
Identifier("global::System.Func"),
GetLambdaTypeArgumentListSyntax(projectable)
)
var classSyntax = ClassDeclaration(generatedClassName)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.ClassTypeParameterList)
.WithConstraintClauses(projectable.ClassConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.AddAttributeLists(
AttributeList()
.AddAttributes(_editorBrowsableAttribute)
)
.AddMembers(
MethodDeclaration(
GenericName(
Identifier("global::System.Linq.Expressions.Expression"),
TypeArgumentList(
SingletonSeparatedList(
(TypeSyntax)GenericName(
Identifier("global::System.Func"),
GetLambdaTypeArgumentListSyntax(projectable)
)
)
),
"Expression"
)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.TypeParameterList)
.WithConstraintClauses(projectable.ConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.WithBody(
Block(
ReturnStatement(
ParenthesizedLambdaExpression(
projectable.ParametersList ?? ParameterList(),
null,
projectable.ExpressionBody
)
)
),
"Expression"
)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.TypeParameterList)
.WithConstraintClauses(projectable.ConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.WithBody(
Block(
ReturnStatement(
ParenthesizedLambdaExpression(
projectable.ParametersList ?? ParameterList(),
null,
projectable.ExpressionBody
)
)
)
);
)
)
);

#nullable disable

var compilationUnit = CompilationUnit();
var compilationUnit = CompilationUnit();

foreach (var usingDirective in projectable.UsingDirectives)
{
compilationUnit = compilationUnit.AddUsings(usingDirective);
}
foreach (var usingDirective in projectable.UsingDirectives)
{
compilationUnit = compilationUnit.AddUsings(usingDirective);
}

if (projectable.ClassNamespace is not null)
{
compilationUnit = compilationUnit.AddUsings(
UsingDirective(
ParseName(projectable.ClassNamespace)
)
);
}
if (projectable.ClassNamespace is not null)
{
compilationUnit = compilationUnit.AddUsings(
UsingDirective(
ParseName(projectable.ClassNamespace)
)
);
}

compilationUnit = compilationUnit
.AddMembers(
NamespaceDeclaration(
ParseName("EntityFrameworkCore.Projectables.Generated")
).AddMembers(classSyntax)
compilationUnit = compilationUnit
.AddMembers(
NamespaceDeclaration(
ParseName("EntityFrameworkCore.Projectables.Generated")
).AddMembers(classSyntax)
)
.WithLeadingTrivia(
TriviaList(
Comment("// <auto-generated/>"),
Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))
)
.WithLeadingTrivia(
TriviaList(
Comment("// <auto-generated/>"),
Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))
)
);
);


context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8));
context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8));

static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable)
{
var lambdaTypeArguments = TypeArgumentList(
SeparatedList(
// TODO: Document where clause
projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!)
)
);

static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable)
if (projectable.ReturnTypeName is not null)
{
var lambdaTypeArguments = TypeArgumentList(
SeparatedList(
// TODO: Document where clause
projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!)
)
);

if (projectable.ReturnTypeName is not null)
{
lambdaTypeArguments = lambdaTypeArguments.AddArguments(ParseTypeName(projectable.ReturnTypeName));
}

return lambdaTypeArguments;
lambdaTypeArguments = lambdaTypeArguments.AddArguments(ParseTypeName(projectable.ReturnTypeName));
}

return lambdaTypeArguments;
}
}
}
Expand Down