diff --git a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs index 535b7ce..ab91d59 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs @@ -113,6 +113,57 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition } + public override SyntaxNode? VisitSwitchExpression(SwitchExpressionSyntax node) + { + // Reverse arms order to start from the default value + var arms = node.Arms.Reverse(); + + ExpressionSyntax? currentExpression = null; + + foreach (var arm in arms) + { + var armExpression = (ExpressionSyntax)Visit(arm.Expression); + + // Handle fallback value + if (currentExpression == null) + { + currentExpression = arm.Pattern is DiscardPatternSyntax + ? armExpression + : SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression); + + continue; + } + + // Handle each arm, only if it's a constant expression + if (arm.Pattern is ConstantPatternSyntax constant) + { + ExpressionSyntax expression = SyntaxFactory.BinaryExpression(SyntaxKind.EqualsExpression, (ExpressionSyntax)Visit(node.GoverningExpression), constant.Expression); + + // Add the when clause as a AND expression + if (arm.WhenClause != null) + { + expression = SyntaxFactory.BinaryExpression( + SyntaxKind.LogicalAndExpression, + expression, + (ExpressionSyntax)Visit(arm.WhenClause.Condition) + ); + } + + currentExpression = SyntaxFactory.ConditionalExpression( + expression, + armExpression, + currentExpression + ); + + continue; + } + + throw new InvalidOperationException("Switch expressions rewriting is only supported with constant values"); + } + + return currentExpression; + } + public override SyntaxNode? VisitMemberBindingExpression(MemberBindingExpressionSyntax node) { if (_conditionalAccessExpressionsStack.Count > 0) diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt new file mode 100644 index 0000000..00c6d5e --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt @@ -0,0 +1,15 @@ +// +#nullable disable +using EntityFrameworkCore.Projectables; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class _Foo_SomeNumber + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo @this, int input) => input == 1 ? 2 : input == 3 ? 4 : input == 4 && @this.FancyNumber == 12 ? 48 : 1000; + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs index b79842d..b2ea1fd 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs @@ -1699,6 +1699,33 @@ class Foo { return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + [Fact] + public Task SwitchExpression() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; + +class Foo { + public int? FancyNumber { get; set; } + + [Projectable(NullConditionalRewriteSupport = NullConditionalRewriteSupport.Rewrite)] + public int SomeNumber(int input) => input switch { + 1 => 2, + 3 => 4, + 4 when FancyNumber == 12 => 48, + _ => 1000, + }; + } +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Single(result.GeneratedTrees); + + return Verifier.Verify(result.GeneratedTrees[0].ToString()); + } + [Fact] public Task GenericTypes() {