diff --git a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs
index 535b7ce..ab91d59 100644
--- a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs
+++ b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs
@@ -113,6 +113,57 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition
}
+ public override SyntaxNode? VisitSwitchExpression(SwitchExpressionSyntax node)
+ {
+ // Reverse arms order to start from the default value
+ var arms = node.Arms.Reverse();
+
+ ExpressionSyntax? currentExpression = null;
+
+ foreach (var arm in arms)
+ {
+ var armExpression = (ExpressionSyntax)Visit(arm.Expression);
+
+ // Handle fallback value
+ if (currentExpression == null)
+ {
+ currentExpression = arm.Pattern is DiscardPatternSyntax
+ ? armExpression
+ : SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression);
+
+ continue;
+ }
+
+ // Handle each arm, only if it's a constant expression
+ if (arm.Pattern is ConstantPatternSyntax constant)
+ {
+ ExpressionSyntax expression = SyntaxFactory.BinaryExpression(SyntaxKind.EqualsExpression, (ExpressionSyntax)Visit(node.GoverningExpression), constant.Expression);
+
+ // Add the when clause as a AND expression
+ if (arm.WhenClause != null)
+ {
+ expression = SyntaxFactory.BinaryExpression(
+ SyntaxKind.LogicalAndExpression,
+ expression,
+ (ExpressionSyntax)Visit(arm.WhenClause.Condition)
+ );
+ }
+
+ currentExpression = SyntaxFactory.ConditionalExpression(
+ expression,
+ armExpression,
+ currentExpression
+ );
+
+ continue;
+ }
+
+ throw new InvalidOperationException("Switch expressions rewriting is only supported with constant values");
+ }
+
+ return currentExpression;
+ }
+
public override SyntaxNode? VisitMemberBindingExpression(MemberBindingExpressionSyntax node)
{
if (_conditionalAccessExpressionsStack.Count > 0)
diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt
new file mode 100644
index 0000000..00c6d5e
--- /dev/null
+++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt
@@ -0,0 +1,15 @@
+//
+#nullable disable
+using EntityFrameworkCore.Projectables;
+
+namespace EntityFrameworkCore.Projectables.Generated
+{
+ [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
+ static class _Foo_SomeNumber
+ {
+ static global::System.Linq.Expressions.Expression> Expression()
+ {
+ return (global::Foo @this, int input) => input == 1 ? 2 : input == 3 ? 4 : input == 4 && @this.FancyNumber == 12 ? 48 : 1000;
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs
index b79842d..b2ea1fd 100644
--- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs
+++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs
@@ -1699,6 +1699,33 @@ class Foo {
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
+ [Fact]
+ public Task SwitchExpression()
+ {
+ var compilation = CreateCompilation(@"
+using EntityFrameworkCore.Projectables;
+
+class Foo {
+ public int? FancyNumber { get; set; }
+
+ [Projectable(NullConditionalRewriteSupport = NullConditionalRewriteSupport.Rewrite)]
+ public int SomeNumber(int input) => input switch {
+ 1 => 2,
+ 3 => 4,
+ 4 when FancyNumber == 12 => 48,
+ _ => 1000,
+ };
+ }
+");
+
+ var result = RunGenerator(compilation);
+
+ Assert.Empty(result.Diagnostics);
+ Assert.Single(result.GeneratedTrees);
+
+ return Verifier.Verify(result.GeneratedTrees[0].ToString());
+ }
+
[Fact]
public Task GenericTypes()
{