Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Generator/Generator.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net5.0</TargetFramework>
<TargetFramework>net6.0</TargetFramework>
<Nullable>enable</Nullable>
</PropertyGroup>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net5.0</TargetFramework>
<TargetFramework>net6.0</TargetFramework>
<Nullable>enable</Nullable>

<IsPackable>false</IsPackable>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.SourceGenerators.Testing.XUnit" Version="1.1.1" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="3.9.0" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.0.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.2.0" />
<PackageReference Include="xunit" Version="2.4.1" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.5">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net5.0</TargetFramework>
<TargetFramework>net6.0</TargetFramework>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
</PropertyGroup>
Expand Down
2 changes: 1 addition & 1 deletion OneOf.SourceGenerator/OneOf.SourceGenerator.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="3.9.0">
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1">
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>
Expand Down
158 changes: 72 additions & 86 deletions OneOf.SourceGenerator/OneOfGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;

namespace OneOf.SourceGenerator
{
[Generator]
public class OneOfGenerator : ISourceGenerator
public class OneOfGenerator : IIncrementalGenerator
{
private const string AttributeName = "GenerateOneOfAttribute";
private const string AttributeNamespace = "OneOf";
Expand All @@ -29,55 +28,95 @@ internal sealed class {AttributeName} : Attribute
}}
";

public void Execute(GeneratorExecutionContext context)
public void Initialize(IncrementalGeneratorInitializationContext context)
{
if (context.SyntaxReceiver is not OneOfSyntaxReceiver receiver)
context.RegisterPostInitializationOutput(ctx => ctx.AddSource($"{AttributeName}.g.cs", _attributeText));

var oneOfClasses = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx))
.Where(static m => m is not null)
.Collect();

context.RegisterSourceOutput(oneOfClasses, Execute);


static bool IsSyntaxTargetForGeneration(SyntaxNode node)
{
return;
return node is ClassDeclarationSyntax {AttributeLists: {Count: > 0}} classDeclarationSyntax
&& classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword);
}

Compilation compilation = context.Compilation;
static INamedTypeSymbol? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
{
var symbol = context.SemanticModel.GetDeclaredSymbol(context.Node);

INamedTypeSymbol? attributeSymbol =
compilation.GetTypeByMetadataName($"{AttributeNamespace}.{AttributeName}");
if (symbol is not INamedTypeSymbol namedTypeSymbol)
{
return null;
}

var attributeData = namedTypeSymbol.GetAttributes().FirstOrDefault(ad =>
string.Equals(ad.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), $"global::{AttributeNamespace}.{AttributeName}"));

if (attributeSymbol is null)
{
return;
return attributeData is null ? null : namedTypeSymbol;
}
}

List<(INamedTypeSymbol, Location?)> namedTypeSymbols = new();
foreach (ClassDeclarationSyntax classDeclaration in receiver.CandidateClasses)
{
SemanticModel model = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
INamedTypeSymbol? namedTypeSymbol = model.GetDeclaredSymbol(classDeclaration);
private static string GenerateClassSource(INamedTypeSymbol classSymbol,
ImmutableArray<ITypeParameterSymbol> typeParameters, ImmutableArray<ITypeSymbol> typeArguments)
{
var paramArgPairs =
typeParameters.Zip(typeArguments, (param, arg) => (param, arg));

AttributeData? attributeData = namedTypeSymbol?.GetAttributes().FirstOrDefault(ad =>
ad.AttributeClass?.Equals(attributeSymbol, SymbolEqualityComparer.Default) != false);
var oneOfGenericPart = GetGenericPart(typeArguments);

if (attributeData is not null)
{
namedTypeSymbols.Add((namedTypeSymbol!,
attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation()));
}
}
var classNameWithGenericTypes = $"{classSymbol.Name}{GetOpenGenericPart(classSymbol)}";

StringBuilder source = new($@"// <auto-generated />
#pragma warning disable 1591

foreach ((INamedTypeSymbol namedSymbol, Location? attributeLocation) in namedTypeSymbols)
namespace {classSymbol.ContainingNamespace.ToDisplayString()}
{{
partial class {classNameWithGenericTypes}");

source.Append($@"
{{
public {classSymbol.Name}(OneOf.OneOf<{oneOfGenericPart}> _) : base(_) {{ }}
");

foreach (var (param, arg) in paramArgPairs)
{
string? classSource = ProcessClass(namedSymbol, context, attributeLocation);
source.Append($@"
public static implicit operator {classNameWithGenericTypes}({arg.ToDisplayString()} _) => new {classNameWithGenericTypes}(_);
public static explicit operator {arg.ToDisplayString()}({classNameWithGenericTypes} _) => _.As{param.Name};
");
}

source.Append(@" }
}");
return source.ToString();
}

private static void Execute(SourceProductionContext context, ImmutableArray<INamedTypeSymbol?> symbols)
{
foreach (var namedTypeSymbol in symbols.Where(symbol => symbol is not null))
{
var classSource = ProcessClass(namedTypeSymbol!, context);

if (classSource is null)
{
continue;
}

context.AddSource($"{namedSymbol.ContainingNamespace}_{namedSymbol.Name}.g.cs", classSource);
context.AddSource($"{namedTypeSymbol!.ContainingNamespace}_{namedTypeSymbol.Name}.g.cs", classSource);
}
}

private static string? ProcessClass(INamedTypeSymbol classSymbol, GeneratorExecutionContext context, Location? attributeLocation)
private static string? ProcessClass(INamedTypeSymbol classSymbol, SourceProductionContext context)
{
attributeLocation ??= Location.None;
var attributeLocation = classSymbol.Locations.FirstOrDefault() ?? Location.None;

if (!classSymbol.ContainingSymbol.Equals(classSymbol.ContainingNamespace, SymbolEqualityComparer.Default))
{
Expand All @@ -91,9 +130,9 @@ public void Execute(GeneratorExecutionContext context)
return null;
}

ImmutableArray<ITypeSymbol> typeArguments = classSymbol.BaseType.TypeArguments;
var typeArguments = classSymbol.BaseType.TypeArguments;

foreach (ITypeSymbol typeSymbol in typeArguments)
foreach (var typeSymbol in typeArguments)
{
if (typeSymbol.Name == nameof(Object))
{
Expand All @@ -111,42 +150,10 @@ public void Execute(GeneratorExecutionContext context)
return GenerateClassSource(classSymbol, classSymbol.BaseType.TypeParameters, typeArguments);

void CreateDiagnosticError(DiagnosticDescriptor descriptor)
=> context.ReportDiagnostic(Diagnostic.Create(descriptor, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
}

private static string GenerateClassSource(INamedTypeSymbol classSymbol,
ImmutableArray<ITypeParameterSymbol> typeParameters, ImmutableArray<ITypeSymbol> typeArguments)
{
IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs =
typeParameters.Zip(typeArguments, (param, arg) => (param, arg));

string oneOfGenericPart = GetGenericPart(typeArguments);

string classNameWithGenericTypes = $"{classSymbol.Name}{GetOpenGenericPart(classSymbol)}";

StringBuilder source = new($@"// <auto-generated />
#pragma warning disable 1591

namespace {classSymbol.ContainingNamespace.ToDisplayString()}
{{
partial class {classNameWithGenericTypes}");

source.Append($@"
{{
public {classSymbol.Name}(OneOf.OneOf<{oneOfGenericPart}> _) : base(_) {{ }}
");

foreach ((ITypeParameterSymbol param, ITypeSymbol arg) in paramArgPairs)
{
source.Append($@"
public static implicit operator {classNameWithGenericTypes}({arg.ToDisplayString()} _) => new {classNameWithGenericTypes}(_);
public static explicit operator {arg.ToDisplayString()}({classNameWithGenericTypes} _) => _.As{param.Name};
");
context.ReportDiagnostic(Diagnostic.Create(descriptor, attributeLocation, classSymbol.Name,
DiagnosticSeverity.Error));
}

source.Append(@" }
}");
return source.ToString();
}

private static string GetGenericPart(ImmutableArray<ITypeSymbol> typeArguments) =>
Expand All @@ -161,26 +168,5 @@ private static string GetGenericPart(ImmutableArray<ITypeSymbol> typeArguments)

return $"<{GetGenericPart(classSymbol.TypeArguments)}>";
}

public void Initialize(GeneratorInitializationContext context)
{
context.RegisterForPostInitialization(ctx =>
ctx.AddSource($"{AttributeName}.g.cs", _attributeText));
context.RegisterForSyntaxNotifications(() => new OneOfSyntaxReceiver());
}

internal class OneOfSyntaxReceiver : ISyntaxReceiver
{
public List<ClassDeclarationSyntax> CandidateClasses { get; } = new();

public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
{
if (syntaxNode is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } } classDeclarationSyntax
&& classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword))
{
CandidateClasses.Add(classDeclarationSyntax);
}
}
}
}
}