diff --git a/Directory.Build.props b/Directory.Build.props index 726e407..0983617 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,6 +1,6 @@ - net8.0;net10.0 + net8.0;net9.0;net10.0 true 12.0 enable diff --git a/Directory.Packages.props b/Directory.Packages.props index 7f8e683..0c92090 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -4,18 +4,21 @@ + + + diff --git a/src/EntityFrameworkCore.Projectables.Generator/AnalyzerReleases.Unshipped.md b/src/EntityFrameworkCore.Projectables.Generator/AnalyzerReleases.Unshipped.md index ef168b8..f6e4fb7 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/AnalyzerReleases.Unshipped.md +++ b/src/EntityFrameworkCore.Projectables.Generator/AnalyzerReleases.Unshipped.md @@ -3,3 +3,4 @@ Rule ID | Category | Severity | Notes --------|----------|----------|-------------------- EFP0002 | Design | Error | +EFP0003 | Design | Error | diff --git a/src/EntityFrameworkCore.Projectables.Generator/Diagnostics.cs b/src/EntityFrameworkCore.Projectables.Generator/Diagnostics.cs index 18f87c1..718240e 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/Diagnostics.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/Diagnostics.cs @@ -25,5 +25,13 @@ public static class Diagnostics DiagnosticSeverity.Error, isEnabledByDefault: true); + public static readonly DiagnosticDescriptor SqlExpressionArgumentCountMismatch = new DiagnosticDescriptor( + id: "EFP0003", + title: "SqlExpression template references out-of-range argument", + messageFormat: "SQL template references argument {{{0}}} but the method only has {1} parameter(s). Valid argument indices range from {{0}} to {{{2}}}.", + category: "Design", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + } } diff --git a/src/EntityFrameworkCore.Projectables.Generator/SqlExpressionAnalyzer.cs b/src/EntityFrameworkCore.Projectables.Generator/SqlExpressionAnalyzer.cs new file mode 100644 index 0000000..ea1d63b --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/SqlExpressionAnalyzer.cs @@ -0,0 +1,76 @@ +using System.Collections.Immutable; +using System.Text.RegularExpressions; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace EntityFrameworkCore.Projectables.Generator +{ + /// + /// Validates that [SqlExpression] SQL templates do not reference argument indices that + /// are out of range for the decorated method's parameter list. + /// + [DiagnosticAnalyzer(LanguageNames.CSharp)] + public class SqlExpressionAnalyzer : DiagnosticAnalyzer + { + private const string SqlExpressionAttributeFullName = "EntityFrameworkCore.Projectables.SqlExpressionAttribute"; + + // Matches any {N} placeholder in the SQL template + private static readonly Regex PlaceholderPattern = + new Regex(@"\{(\d+)\}", RegexOptions.Compiled); + + public override ImmutableArray SupportedDiagnostics => + ImmutableArray.Create(Diagnostics.SqlExpressionArgumentCountMismatch); + + public override void Initialize(AnalysisContext context) + { + context.EnableConcurrentExecution(); + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + context.RegisterSymbolAction(AnalyzeMethod, SymbolKind.Method); + } + + private static void AnalyzeMethod(SymbolAnalysisContext context) + { + var method = (IMethodSymbol)context.Symbol; + + var sqlExprAttrType = context.Compilation.GetTypeByMetadataName(SqlExpressionAttributeFullName); + if (sqlExprAttrType is null) + return; + + var paramCount = method.Parameters.Length; + + foreach (var attr in method.GetAttributes()) + { + if (!SymbolEqualityComparer.Default.Equals(attr.AttributeClass, sqlExprAttrType)) + continue; + + if (attr.ConstructorArguments.Length == 0) + continue; + + var sqlTemplate = attr.ConstructorArguments[0].Value as string; + if (sqlTemplate is null) + continue; + + var maxIndex = -1; + foreach (Match m in PlaceholderPattern.Matches(sqlTemplate)) + { + var idx = int.Parse(m.Groups[1].Value); + if (idx > maxIndex) + maxIndex = idx; + } + + if (maxIndex >= paramCount) + { + var location = attr.ApplicationSyntaxReference?.GetSyntax().GetLocation() + ?? method.Locations[0]; + + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.SqlExpressionArgumentCountMismatch, + location, + maxIndex, // {0} – the out-of-range index referenced + paramCount, // {1} – how many parameters the method has + paramCount - 1)); // {2} – the maximum valid index + } + } + } + } +} diff --git a/src/EntityFrameworkCore.Projectables/EntityFrameworkCore.Projectables.csproj b/src/EntityFrameworkCore.Projectables/EntityFrameworkCore.Projectables.csproj index 4be6e6b..5b47004 100644 --- a/src/EntityFrameworkCore.Projectables/EntityFrameworkCore.Projectables.csproj +++ b/src/EntityFrameworkCore.Projectables/EntityFrameworkCore.Projectables.csproj @@ -5,6 +5,7 @@ + diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/ProjectionOptionsExtension.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/ProjectionOptionsExtension.cs index fbfa4be..a65d749 100644 --- a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/ProjectionOptionsExtension.cs +++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/ProjectionOptionsExtension.cs @@ -41,6 +41,9 @@ public void ApplyServices(IServiceCollection services) // Register a convention that will ignore properties marked with the ProjectableAttribute services.AddScoped(); + // Register the translator plugin so that [SqlExpression]-decorated methods are translated + services.AddScoped(); + static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor descriptor) { if (descriptor.ImplementationInstance is not null) diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/SqlExpressionMethodCallTranslator.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/SqlExpressionMethodCallTranslator.cs new file mode 100644 index 0000000..cddadd5 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/SqlExpressionMethodCallTranslator.cs @@ -0,0 +1,152 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.RegularExpressions; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace EntityFrameworkCore.Projectables.Infrastructure.Internal +{ + /// + /// Translates calls to methods decorated with into + /// the corresponding SQL expressions. + /// + public class SqlExpressionMethodCallTranslator : IMethodCallTranslator + { + // Matches FUNCNAME(content) — captures function name and full argument list + private static readonly Regex FunctionCallPattern = + new Regex(@"^(\w[\w.]*)\s*\((.+)\)\s*$", RegexOptions.Compiled | RegexOptions.Singleline); + + // Matches a standalone {N} placeholder (the entire token) + private static readonly Regex StandaloneArgumentPlaceholderPattern = + new Regex(@"^\{(\d+)\}$", RegexOptions.Compiled); + + private readonly ISqlExpressionFactory _sqlExpressionFactory; + private readonly string? _providerName; + + public SqlExpressionMethodCallTranslator(ISqlExpressionFactory sqlExpressionFactory, string? providerName = null) + { + _sqlExpressionFactory = sqlExpressionFactory; + _providerName = providerName; + } + + /// + public SqlExpression? Translate( + SqlExpression? instance, + MethodInfo method, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + var sqlExpressionAttrs = method.GetCustomAttributes().ToArray(); + if (sqlExpressionAttrs.Length == 0) + return null; + + // Prefer an attribute whose Configuration matches the current provider name. + SqlExpressionAttribute? selectedAttr = null; + if (_providerName != null) + { + selectedAttr = sqlExpressionAttrs.FirstOrDefault(a => + a.Configuration != null && + _providerName.Contains(a.Configuration, StringComparison.OrdinalIgnoreCase)); + } + + // Fall back to an attribute without a Configuration (provider-agnostic). + selectedAttr ??= sqlExpressionAttrs.FirstOrDefault(a => a.Configuration is null); + + if (selectedAttr is null) + return null; + + return TranslateTemplate(selectedAttr.Sql, arguments, method.ReturnType); + } + + private SqlExpression? TranslateTemplate( + string template, + IReadOnlyList arguments, + Type returnType) + { + var match = FunctionCallPattern.Match(template.Trim()); + if (!match.Success) + return null; + + var functionName = match.Groups[1].Value; + var argsSection = match.Groups[2].Value; + + var sqlArgs = new List(); + var nullPropagation = new List(); + + foreach (var token in SplitArguments(argsSection)) + { + var t = token.Trim(); + var placeholderMatch = StandaloneArgumentPlaceholderPattern.Match(t); + if (placeholderMatch.Success) + { + var index = int.Parse(placeholderMatch.Groups[1].Value); + if (index >= arguments.Count) + { + throw new InvalidOperationException( + $"SQL template '{template}' references argument {{{index}}} but the method only has {arguments.Count} argument(s). Valid indices are 0 to {arguments.Count - 1}."); + } + sqlArgs.Add(arguments[index]); + nullPropagation.Add(true); + } + else + { + // Literal SQL fragment (e.g. '%Y' in STRFTIME('%Y', {0})) + sqlArgs.Add(_sqlExpressionFactory.Fragment(t)); + nullPropagation.Add(false); + } + } + + return _sqlExpressionFactory.Function( + functionName, + sqlArgs, + nullable: true, + argumentsPropagateNullability: nullPropagation, + returnType); + } + + /// + /// Splits a SQL argument list string on top-level commas, respecting + /// single-quoted string literals and nested parentheses. + /// + private static IEnumerable SplitArguments(string args) + { + var depth = 0; + var inSingleQuote = false; + var start = 0; + + for (var i = 0; i < args.Length; i++) + { + var c = args[i]; + + if (c == '\'' && !inSingleQuote) + { + inSingleQuote = true; + } + else if (c == '\'' && inSingleQuote) + { + // Handle escaped single quotes ('') + if (i + 1 < args.Length && args[i + 1] == '\'') + i++; + else + inSingleQuote = false; + } + else if (!inSingleQuote) + { + if (c == '(') depth++; + else if (c == ')') depth--; + else if (c == ',' && depth == 0) + { + yield return args[start..i]; + start = i + 1; + } + } + } + + yield return args[start..]; + } + } +} diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/SqlExpressionMethodCallTranslatorPlugin.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/SqlExpressionMethodCallTranslatorPlugin.cs new file mode 100644 index 0000000..cc5425e --- /dev/null +++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/SqlExpressionMethodCallTranslatorPlugin.cs @@ -0,0 +1,26 @@ +using System.Collections.Generic; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Query; + +namespace EntityFrameworkCore.Projectables.Infrastructure.Internal +{ + /// + /// Registers the with EF Core's method call + /// translation pipeline so that methods decorated with + /// are translated to the corresponding SQL expressions. + /// + public class SqlExpressionMethodCallTranslatorPlugin : IMethodCallTranslatorPlugin + { + public SqlExpressionMethodCallTranslatorPlugin(ISqlExpressionFactory sqlExpressionFactory, ICurrentDbContext currentDbContext) + { + var providerName = currentDbContext.Context.Database.ProviderName; + Translators = new IMethodCallTranslator[] + { + new SqlExpressionMethodCallTranslator(sqlExpressionFactory, providerName) + }; + } + + /// + public IEnumerable Translators { get; } + } +} diff --git a/src/EntityFrameworkCore.Projectables/SqlExpressionAttribute.cs b/src/EntityFrameworkCore.Projectables/SqlExpressionAttribute.cs new file mode 100644 index 0000000..d5ff7a0 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables/SqlExpressionAttribute.cs @@ -0,0 +1,59 @@ +using System; +using System.Diagnostics.CodeAnalysis; + +namespace EntityFrameworkCore.Projectables +{ + /// + /// Decorates a static method with a SQL template string that will be used to translate + /// the method call into a SQL expression when used in a LINQ query against EF Core. + /// Use positional placeholders {0}, {1}, etc. to refer to the method arguments. + /// Multiple instances of this attribute may be applied to the same method, each with a + /// different value, to provide provider-specific SQL expressions. + /// + /// + /// + /// [SqlExpression("SOUNDEX({0})")] + /// public static string Soundex(string value) => throw new NotImplementedException(); + /// + /// [SqlExpression("COALESCE({0}, {1})")] + /// public static string Coalesce(string value, string fallback) => throw new NotImplementedException(); + /// + /// [SqlExpression("STRFTIME('%Y', {0})", Configuration = "Sqlite")] + /// [SqlExpression("YEAR({0})", Configuration = "SqlServer")] + /// [SqlExpression("EXTRACT(YEAR FROM {0})", Configuration = "Npgsql")] + /// public static int Year(DateTime date) => throw new NotImplementedException(); + /// + /// + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + public sealed class SqlExpressionAttribute : Attribute + { + /// + /// Initializes a new instance of with the given SQL template. + /// + /// + /// The SQL template. Use {0}, {1}, etc. as positional placeholders for method arguments. + /// + public SqlExpressionAttribute([StringSyntax("sql")] string sql) + { + Sql = sql; + } + + /// + /// The SQL template string with positional argument placeholders ({0}, {1}, etc.). + /// + public string Sql { get; } + + /// + /// When true (the default), the method can only be evaluated server-side and must + /// throw in its body. + /// + public bool ServerSideOnly { get; set; } = true; + + /// + /// When set, this attribute only applies when the database provider name contains this value + /// (e.g. "SqlServer", "Sqlite", "Npgsql"). + /// When null (the default), the attribute acts as a fallback for any provider. + /// + public string? Configuration { get; set; } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SqliteSampleDbContext.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SqliteSampleDbContext.cs new file mode 100644 index 0000000..fe4b11c --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SqliteSampleDbContext.cs @@ -0,0 +1,20 @@ +using EntityFrameworkCore.Projectables.Infrastructure; +using Microsoft.EntityFrameworkCore; + +namespace EntityFrameworkCore.Projectables.FunctionalTests.Helpers +{ + public class SqliteSampleDbContext : DbContext + where TEntity : class + { + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + optionsBuilder.UseSqlite("Data Source=:memory:"); + optionsBuilder.UseProjectables(); + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity(); + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithExtensionMethodStrftimeOnSqlite.DotNet10_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithExtensionMethodStrftimeOnSqlite.DotNet10_0.verified.txt new file mode 100644 index 0000000..ad20e4b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithExtensionMethodStrftimeOnSqlite.DotNet10_0.verified.txt @@ -0,0 +1,2 @@ +SELECT STRFTIME('%Y', "d"."CreatedAt") +FROM "DateEntity" AS "d" \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithExtensionMethodStrftimeOnSqlite.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithExtensionMethodStrftimeOnSqlite.DotNet9_0.verified.txt new file mode 100644 index 0000000..ad20e4b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithExtensionMethodStrftimeOnSqlite.DotNet9_0.verified.txt @@ -0,0 +1,2 @@ +SELECT STRFTIME('%Y', "d"."CreatedAt") +FROM "DateEntity" AS "d" \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithExtensionMethodStrftimeOnSqlite.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithExtensionMethodStrftimeOnSqlite.verified.txt new file mode 100644 index 0000000..ad20e4b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithExtensionMethodStrftimeOnSqlite.verified.txt @@ -0,0 +1,2 @@ +SELECT STRFTIME('%Y', "d"."CreatedAt") +FROM "DateEntity" AS "d" \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithFallbackSqlExpression.DotNet10_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithFallbackSqlExpression.DotNet10_0.verified.txt new file mode 100644 index 0000000..0bf9b6b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithFallbackSqlExpression.DotNet10_0.verified.txt @@ -0,0 +1,2 @@ +SELECT YEAR([d].[CreatedAt]) +FROM [DateEntity] AS [d] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithFallbackSqlExpression.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithFallbackSqlExpression.DotNet9_0.verified.txt new file mode 100644 index 0000000..0bf9b6b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithFallbackSqlExpression.DotNet9_0.verified.txt @@ -0,0 +1,2 @@ +SELECT YEAR([d].[CreatedAt]) +FROM [DateEntity] AS [d] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithFallbackSqlExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithFallbackSqlExpression.verified.txt new file mode 100644 index 0000000..0bf9b6b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithFallbackSqlExpression.verified.txt @@ -0,0 +1,2 @@ +SELECT YEAR([d].[CreatedAt]) +FROM [DateEntity] AS [d] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithProviderSpecificSqlExpression.DotNet10_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithProviderSpecificSqlExpression.DotNet10_0.verified.txt new file mode 100644 index 0000000..0bf9b6b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithProviderSpecificSqlExpression.DotNet10_0.verified.txt @@ -0,0 +1,2 @@ +SELECT YEAR([d].[CreatedAt]) +FROM [DateEntity] AS [d] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithProviderSpecificSqlExpression.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithProviderSpecificSqlExpression.DotNet9_0.verified.txt new file mode 100644 index 0000000..0bf9b6b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithProviderSpecificSqlExpression.DotNet9_0.verified.txt @@ -0,0 +1,2 @@ +SELECT YEAR([d].[CreatedAt]) +FROM [DateEntity] AS [d] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithProviderSpecificSqlExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithProviderSpecificSqlExpression.verified.txt new file mode 100644 index 0000000..0bf9b6b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithProviderSpecificSqlExpression.verified.txt @@ -0,0 +1,2 @@ +SELECT YEAR([d].[CreatedAt]) +FROM [DateEntity] AS [d] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithSqlExpressionCoalesce.DotNet10_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithSqlExpressionCoalesce.DotNet10_0.verified.txt new file mode 100644 index 0000000..fb7617b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithSqlExpressionCoalesce.DotNet10_0.verified.txt @@ -0,0 +1,2 @@ +SELECT COALESCE([e].[NickName], [e].[Name]) +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithSqlExpressionCoalesce.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithSqlExpressionCoalesce.DotNet9_0.verified.txt new file mode 100644 index 0000000..fb7617b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithSqlExpressionCoalesce.DotNet9_0.verified.txt @@ -0,0 +1,2 @@ +SELECT COALESCE([e].[NickName], [e].[Name]) +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithSqlExpressionCoalesce.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithSqlExpressionCoalesce.verified.txt new file mode 100644 index 0000000..fb7617b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithSqlExpressionCoalesce.verified.txt @@ -0,0 +1,2 @@ +SELECT COALESCE([e].[NickName], [e].[Name]) +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithStrftimeOnSqlite.DotNet10_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithStrftimeOnSqlite.DotNet10_0.verified.txt new file mode 100644 index 0000000..ad20e4b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithStrftimeOnSqlite.DotNet10_0.verified.txt @@ -0,0 +1,2 @@ +SELECT STRFTIME('%Y', "d"."CreatedAt") +FROM "DateEntity" AS "d" \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithStrftimeOnSqlite.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithStrftimeOnSqlite.DotNet9_0.verified.txt new file mode 100644 index 0000000..ad20e4b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithStrftimeOnSqlite.DotNet9_0.verified.txt @@ -0,0 +1,2 @@ +SELECT STRFTIME('%Y', "d"."CreatedAt") +FROM "DateEntity" AS "d" \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithStrftimeOnSqlite.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithStrftimeOnSqlite.verified.txt new file mode 100644 index 0000000..ad20e4b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.SelectWithStrftimeOnSqlite.verified.txt @@ -0,0 +1,2 @@ +SELECT STRFTIME('%Y', "d"."CreatedAt") +FROM "DateEntity" AS "d" \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithExtensionMethodSqlExpression.DotNet10_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithExtensionMethodSqlExpression.DotNet10_0.verified.txt new file mode 100644 index 0000000..02dbcc6 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithExtensionMethodSqlExpression.DotNet10_0.verified.txt @@ -0,0 +1,3 @@ +SELECT [e].[Id], [e].[Name], [e].[NickName] +FROM [Entity] AS [e] +WHERE UPPER([e].[Name]) = N'ALICE' \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithExtensionMethodSqlExpression.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithExtensionMethodSqlExpression.DotNet9_0.verified.txt new file mode 100644 index 0000000..02dbcc6 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithExtensionMethodSqlExpression.DotNet9_0.verified.txt @@ -0,0 +1,3 @@ +SELECT [e].[Id], [e].[Name], [e].[NickName] +FROM [Entity] AS [e] +WHERE UPPER([e].[Name]) = N'ALICE' \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithExtensionMethodSqlExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithExtensionMethodSqlExpression.verified.txt new file mode 100644 index 0000000..02dbcc6 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithExtensionMethodSqlExpression.verified.txt @@ -0,0 +1,3 @@ +SELECT [e].[Id], [e].[Name], [e].[NickName] +FROM [Entity] AS [e] +WHERE UPPER([e].[Name]) = N'ALICE' \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithSqlExpressionUpper.DotNet10_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithSqlExpressionUpper.DotNet10_0.verified.txt new file mode 100644 index 0000000..02dbcc6 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithSqlExpressionUpper.DotNet10_0.verified.txt @@ -0,0 +1,3 @@ +SELECT [e].[Id], [e].[Name], [e].[NickName] +FROM [Entity] AS [e] +WHERE UPPER([e].[Name]) = N'ALICE' \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithSqlExpressionUpper.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithSqlExpressionUpper.DotNet9_0.verified.txt new file mode 100644 index 0000000..02dbcc6 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithSqlExpressionUpper.DotNet9_0.verified.txt @@ -0,0 +1,3 @@ +SELECT [e].[Id], [e].[Name], [e].[NickName] +FROM [Entity] AS [e] +WHERE UPPER([e].[Name]) = N'ALICE' \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithSqlExpressionUpper.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithSqlExpressionUpper.verified.txt new file mode 100644 index 0000000..02dbcc6 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.WhereWithSqlExpressionUpper.verified.txt @@ -0,0 +1,3 @@ +SELECT [e].[Id], [e].[Name], [e].[NickName] +FROM [Entity] AS [e] +WHERE UPPER([e].[Name]) = N'ALICE' \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.cs new file mode 100644 index 0000000..0ea84e2 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/SqlExpressionTests.cs @@ -0,0 +1,148 @@ +using System; +using System.Linq; +using System.Threading.Tasks; +using EntityFrameworkCore.Projectables.FunctionalTests.Helpers; +using Microsoft.EntityFrameworkCore; +using VerifyXunit; +using Xunit; + +namespace EntityFrameworkCore.Projectables.FunctionalTests +{ + /// + /// Extension-method-style SQL functions – same SQL templates as , + /// but declared with a this parameter so they can be called with instance syntax. + /// Must be a top-level (non-nested) static class because C# requires extension methods there. + /// + public static class SqlExtensionFunctions + { + [SqlExpression("UPPER({0})")] + public static string Upper(this string value) => throw new NotImplementedException(); + + [SqlExpression("STRFTIME('%Y', {0})", Configuration = "Sqlite")] + [SqlExpression("YEAR({0})", Configuration = "SqlServer")] + public static int Year(this DateTime date) => throw new NotImplementedException(); + } + + [UsesVerify] + public class SqlExpressionTests + { + public static class Functions + { + [SqlExpression("UPPER({0})")] + public static string Upper(string value) => throw new NotImplementedException(); + + [SqlExpression("COALESCE({0}, {1})")] + public static string Coalesce(string? value, string? fallback) => throw new NotImplementedException(); + + [SqlExpression("STRFTIME('%Y', {0})", Configuration = "Sqlite")] + [SqlExpression("YEAR({0})", Configuration = "SqlServer")] + [SqlExpression("EXTRACT(YEAR FROM {0})", Configuration = "Npgsql")] + public static int Year(DateTime date) => throw new NotImplementedException(); + + [SqlExpression("GENERIC_YEAR({0})")] + [SqlExpression("YEAR({0})", Configuration = "SqlServer")] + public static int YearWithFallback(DateTime date) => throw new NotImplementedException(); + } + + public record Entity + { + public int Id { get; set; } + public string Name { get; set; } = ""; + public string? NickName { get; set; } + } + + public record DateEntity + { + public int Id { get; set; } + public DateTime CreatedAt { get; set; } + } + + [Fact] + public Task WhereWithSqlExpressionUpper() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Where(x => Functions.Upper(x.Name) == "ALICE"); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task SelectWithSqlExpressionCoalesce() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Select(x => Functions.Coalesce(x.NickName, x.Name)); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task SelectWithProviderSpecificSqlExpression() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Select(x => Functions.Year(x.CreatedAt)); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task SelectWithFallbackSqlExpression() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Select(x => Functions.YearWithFallback(x.CreatedAt)); + + return Verifier.Verify(query.ToQueryString()); + } + + /// + /// Verifies that is translated correctly when + /// called with extension-method syntax (x.Name.Upper()). + /// + [Fact] + public Task WhereWithExtensionMethodSqlExpression() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Where(x => x.Name.Upper() == "ALICE"); + + return Verifier.Verify(query.ToQueryString()); + } + + /// + /// Verifies that a provider-specific template (STRFTIME('%Y', {0})) with a + /// literal SQL fragment mixed into the argument list is translated correctly on SQLite. + /// + [Fact] + public Task SelectWithStrftimeOnSqlite() + { + using var dbContext = new SqliteSampleDbContext(); + + var query = dbContext.Set() + .Select(x => Functions.Year(x.CreatedAt)); + + return Verifier.Verify(query.ToQueryString()); + } + + /// + /// Verifies the same scenario via extension-method syntax on SQLite. + /// + [Fact] + public Task SelectWithExtensionMethodStrftimeOnSqlite() + { + using var dbContext = new SqliteSampleDbContext(); + + var query = dbContext.Set() + .Select(x => x.CreatedAt.Year()); + + return Verifier.Verify(query.ToQueryString()); + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/SqlExpressionAnalyzerTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/SqlExpressionAnalyzerTests.cs new file mode 100644 index 0000000..6db70ae --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/SqlExpressionAnalyzerTests.cs @@ -0,0 +1,199 @@ +using System.Collections.Immutable; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using Xunit; +using Xunit.Abstractions; + +namespace EntityFrameworkCore.Projectables.Generator.Tests +{ + public class SqlExpressionAnalyzerTests + { + readonly ITestOutputHelper _testOutputHelper; + + public SqlExpressionAnalyzerTests(ITestOutputHelper testOutputHelper) + { + _testOutputHelper = testOutputHelper; + } + + // ------------------------------------------------------------------ helpers + + private Compilation CreateCompilation(string source) + { + var references = Basic.Reference.Assemblies. +#if NET10_0 + Net100 +#elif NET9_0 + Net90 +#elif NET8_0 + Net80 +#endif + .References.All.ToList(); + + // Add abstractions assembly (ProjectableAttribute) + references.Add(MetadataReference.CreateFromFile(typeof(ProjectableAttribute).Assembly.Location)); + // Add main project assembly (SqlExpressionAttribute) + references.Add(MetadataReference.CreateFromFile(typeof(SqlExpressionAttribute).Assembly.Location)); + + return CSharpCompilation.Create("compilation", + new[] { CSharpSyntaxTree.ParseText(source) }, + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + } + + private async Task> RunAnalyzerAsync(Compilation compilation) + { + var analyzer = new SqlExpressionAnalyzer(); + var withAnalyzers = compilation.WithAnalyzers( + ImmutableArray.Create(analyzer)); + return await withAnalyzers.GetAnalyzerDiagnosticsAsync(); + } + + private ImmutableArray Efp0003Diagnostics(ImmutableArray all) + => all.Where(d => d.Id == "EFP0003").ToImmutableArray(); + + // ------------------------------------------------------------------ tests + + [Fact] + public async Task NoDiagnostic_WhenArgCountMatches_SingleArg() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +public static class Fns +{ + [SqlExpression(""YEAR({0})"")] + public static int Year(DateTime date) => throw new NotImplementedException(); +}"); + var diagnostics = Efp0003Diagnostics(await RunAnalyzerAsync(compilation)); + Assert.Empty(diagnostics); + } + + [Fact] + public async Task NoDiagnostic_WhenArgCountMatches_MultipleArgs() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +public static class Fns +{ + [SqlExpression(""COALESCE({0}, {1})"")] + public static string Coalesce(string a, string b) => throw new NotImplementedException(); +}"); + var diagnostics = Efp0003Diagnostics(await RunAnalyzerAsync(compilation)); + Assert.Empty(diagnostics); + } + + [Fact] + public async Task NoDiagnostic_WhenExtensionMethodWithCorrectArgCount() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +public static class Fns +{ + [SqlExpression(""YEAR({0})"")] + public static int Year(this DateTime date) => throw new NotImplementedException(); +}"); + var diagnostics = Efp0003Diagnostics(await RunAnalyzerAsync(compilation)); + Assert.Empty(diagnostics); + } + + [Fact] + public async Task NoDiagnostic_WhenNoPlaceholders() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +public static class Fns +{ + [SqlExpression(""GETDATE()"")] + public static DateTime Now() => throw new NotImplementedException(); +}"); + var diagnostics = Efp0003Diagnostics(await RunAnalyzerAsync(compilation)); + Assert.Empty(diagnostics); + } + + [Fact] + public async Task NoDiagnostic_WhenMultipleConfigurations_AllValid() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +public static class Fns +{ + [SqlExpression(""YEAR({0})"", Configuration = ""SqlServer"")] + [SqlExpression(""STRFTIME('%Y', {0})"", Configuration = ""Sqlite"")] + public static int Year(DateTime date) => throw new NotImplementedException(); +}"); + var diagnostics = Efp0003Diagnostics(await RunAnalyzerAsync(compilation)); + Assert.Empty(diagnostics); + } + + [Fact] + public async Task ReportsDiagnostic_WhenIndexExceedsParamCount_NoParams() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +public static class Fns +{ + [SqlExpression(""YEAR({0})"")] + public static int Year() => throw new NotImplementedException(); +}"); + var diagnostics = Efp0003Diagnostics(await RunAnalyzerAsync(compilation)); + Assert.Single(diagnostics); + } + + [Fact] + public async Task ReportsDiagnostic_WhenIndexExceedsParamCount_OneParam() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +public static class Fns +{ + [SqlExpression(""COALESCE({0}, {1})"")] + public static string Coalesce(string a) => throw new NotImplementedException(); +}"); + var diagnostics = Efp0003Diagnostics(await RunAnalyzerAsync(compilation)); + Assert.Single(diagnostics); + } + + [Fact] + public async Task ReportsDiagnostic_ForEachInvalidAttribute() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +public static class Fns +{ + [SqlExpression(""YEAR({0})"", Configuration = ""SqlServer"")] + [SqlExpression(""YEAR({0})"", Configuration = ""Sqlite"")] + public static int Year() => throw new NotImplementedException(); +}"); + var diagnostics = Efp0003Diagnostics(await RunAnalyzerAsync(compilation)); + // One diagnostic per invalid attribute + Assert.Equal(2, diagnostics.Length); + } + + [Fact] + public async Task ReportsDiagnostic_OnlyForInvalidAttribute_InMixedList() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +public static class Fns +{ + [SqlExpression(""YEAR({0})"", Configuration = ""SqlServer"")] + [SqlExpression(""YEAR({0})"", Configuration = ""Sqlite"")] + public static int Year(DateTime date) => throw new NotImplementedException(); +}"); + // All attributes are valid here (1 param, max index 0) + var diagnostics = Efp0003Diagnostics(await RunAnalyzerAsync(compilation)); + Assert.Empty(diagnostics); + } + } +}