Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<Project>
<PropertyGroup>
<TargetFrameworks>net8.0;net10.0</TargetFrameworks>
<TargetFrameworks>net8.0;net9.0;net10.0</TargetFrameworks>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<LangVersion>12.0</LangVersion>
<Nullable>enable</Nullable>
Expand Down
3 changes: 3 additions & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
</PropertyGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net8.0'">
<PackageVersion Include="Microsoft.EntityFrameworkCore" Version="8.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="8.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="8.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="8.0.0" />
<PackageVersion Include="Basic.Reference.Assemblies.Net80" Version="1.8.3" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net9.0'">
<PackageVersion Include="Microsoft.EntityFrameworkCore" Version="9.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="9.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="9.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="9.0.0" />
<PackageVersion Include="Basic.Reference.Assemblies.Net90" Version="1.8.3" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net10.0'">
<PackageVersion Include="Microsoft.EntityFrameworkCore" Version="10.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="10.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="10.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.0" />
<PackageVersion Include="Basic.Reference.Assemblies.Net100" Version="1.8.3" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
Rule ID | Category | Severity | Notes
--------|----------|----------|--------------------
EFP0002 | Design | Error |
EFP0003 | Design | Error |
8 changes: 8 additions & 0 deletions src/EntityFrameworkCore.Projectables.Generator/Diagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,13 @@ public static class Diagnostics
DiagnosticSeverity.Error,
isEnabledByDefault: true);

public static readonly DiagnosticDescriptor SqlExpressionArgumentCountMismatch = new DiagnosticDescriptor(
id: "EFP0003",
title: "SqlExpression template references out-of-range argument",
messageFormat: "SQL template references argument {{{0}}} but the method only has {1} parameter(s). Valid argument indices range from {{0}} to {{{2}}}.",
category: "Design",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using System.Collections.Immutable;
using System.Text.RegularExpressions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;

namespace EntityFrameworkCore.Projectables.Generator
{
/// <summary>
/// Validates that <c>[SqlExpression]</c> SQL templates do not reference argument indices that
/// are out of range for the decorated method's parameter list.
/// </summary>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class SqlExpressionAnalyzer : DiagnosticAnalyzer
{
private const string SqlExpressionAttributeFullName = "EntityFrameworkCore.Projectables.SqlExpressionAttribute";

// Matches any {N} placeholder in the SQL template
private static readonly Regex PlaceholderPattern =
new Regex(@"\{(\d+)\}", RegexOptions.Compiled);

public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics =>
ImmutableArray.Create(Diagnostics.SqlExpressionArgumentCountMismatch);

public override void Initialize(AnalysisContext context)
{
context.EnableConcurrentExecution();
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
context.RegisterSymbolAction(AnalyzeMethod, SymbolKind.Method);
}

private static void AnalyzeMethod(SymbolAnalysisContext context)
{
var method = (IMethodSymbol)context.Symbol;

var sqlExprAttrType = context.Compilation.GetTypeByMetadataName(SqlExpressionAttributeFullName);
if (sqlExprAttrType is null)
return;

var paramCount = method.Parameters.Length;

foreach (var attr in method.GetAttributes())
{
if (!SymbolEqualityComparer.Default.Equals(attr.AttributeClass, sqlExprAttrType))
continue;

if (attr.ConstructorArguments.Length == 0)
continue;

var sqlTemplate = attr.ConstructorArguments[0].Value as string;
if (sqlTemplate is null)
continue;

var maxIndex = -1;
foreach (Match m in PlaceholderPattern.Matches(sqlTemplate))
{
var idx = int.Parse(m.Groups[1].Value);
if (idx > maxIndex)
maxIndex = idx;
}

if (maxIndex >= paramCount)
{
var location = attr.ApplicationSyntaxReference?.GetSyntax().GetLocation()
?? method.Locations[0];

context.ReportDiagnostic(Diagnostic.Create(
Diagnostics.SqlExpressionArgumentCountMismatch,
location,
maxIndex, // {0} – the out-of-range index referenced
paramCount, // {1} – how many parameters the method has
paramCount - 1)); // {2} – the maximum valid index
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ public void ApplyServices(IServiceCollection services)
// Register a convention that will ignore properties marked with the ProjectableAttribute
services.AddScoped<IConventionSetPlugin, ProjectablePropertiesNotMappedConventionPlugin>();

// Register the translator plugin so that [SqlExpression]-decorated methods are translated
services.AddScoped<IMethodCallTranslatorPlugin, SqlExpressionMethodCallTranslatorPlugin>();

static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor descriptor)
{
if (descriptor.ImplementationInstance is not null)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace EntityFrameworkCore.Projectables.Infrastructure.Internal
{
/// <summary>
/// Translates calls to methods decorated with <see cref="SqlExpressionAttribute"/> into
/// the corresponding SQL expressions.
/// </summary>
public class SqlExpressionMethodCallTranslator : IMethodCallTranslator
{
// Matches FUNCNAME(content) — captures function name and full argument list
private static readonly Regex FunctionCallPattern =
new Regex(@"^(\w[\w.]*)\s*\((.+)\)\s*$", RegexOptions.Compiled | RegexOptions.Singleline);

// Matches a standalone {N} placeholder (the entire token)
private static readonly Regex StandaloneArgumentPlaceholderPattern =
new Regex(@"^\{(\d+)\}$", RegexOptions.Compiled);

private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly string? _providerName;

public SqlExpressionMethodCallTranslator(ISqlExpressionFactory sqlExpressionFactory, string? providerName = null)
{
_sqlExpressionFactory = sqlExpressionFactory;
_providerName = providerName;
}

/// <inheritdoc />
public SqlExpression? Translate(
SqlExpression? instance,
MethodInfo method,
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
var sqlExpressionAttrs = method.GetCustomAttributes<SqlExpressionAttribute>().ToArray();
if (sqlExpressionAttrs.Length == 0)
return null;

// Prefer an attribute whose Configuration matches the current provider name.
SqlExpressionAttribute? selectedAttr = null;
if (_providerName != null)
{
selectedAttr = sqlExpressionAttrs.FirstOrDefault(a =>
a.Configuration != null &&
_providerName.Contains(a.Configuration, StringComparison.OrdinalIgnoreCase));
}

// Fall back to an attribute without a Configuration (provider-agnostic).
selectedAttr ??= sqlExpressionAttrs.FirstOrDefault(a => a.Configuration is null);

if (selectedAttr is null)
return null;

return TranslateTemplate(selectedAttr.Sql, arguments, method.ReturnType);
}

private SqlExpression? TranslateTemplate(
string template,
IReadOnlyList<SqlExpression> arguments,
Type returnType)
{
var match = FunctionCallPattern.Match(template.Trim());
if (!match.Success)
return null;

var functionName = match.Groups[1].Value;
var argsSection = match.Groups[2].Value;

var sqlArgs = new List<SqlExpression>();
var nullPropagation = new List<bool>();

foreach (var token in SplitArguments(argsSection))
{
var t = token.Trim();
var placeholderMatch = StandaloneArgumentPlaceholderPattern.Match(t);
if (placeholderMatch.Success)
{
var index = int.Parse(placeholderMatch.Groups[1].Value);
if (index >= arguments.Count)
{
throw new InvalidOperationException(
$"SQL template '{template}' references argument {{{index}}} but the method only has {arguments.Count} argument(s). Valid indices are 0 to {arguments.Count - 1}.");
}
sqlArgs.Add(arguments[index]);
nullPropagation.Add(true);
}
else
{
// Literal SQL fragment (e.g. '%Y' in STRFTIME('%Y', {0}))
sqlArgs.Add(_sqlExpressionFactory.Fragment(t));
nullPropagation.Add(false);
}
}

return _sqlExpressionFactory.Function(
functionName,
sqlArgs,
nullable: true,
argumentsPropagateNullability: nullPropagation,
returnType);
}

/// <summary>
/// Splits a SQL argument list string on top-level commas, respecting
/// single-quoted string literals and nested parentheses.
/// </summary>
private static IEnumerable<string> SplitArguments(string args)
{
var depth = 0;
var inSingleQuote = false;
var start = 0;

for (var i = 0; i < args.Length; i++)
{
var c = args[i];

if (c == '\'' && !inSingleQuote)
{
inSingleQuote = true;
}
else if (c == '\'' && inSingleQuote)
{
// Handle escaped single quotes ('')
if (i + 1 < args.Length && args[i + 1] == '\'')
i++;
else
inSingleQuote = false;
}
else if (!inSingleQuote)
{
if (c == '(') depth++;
else if (c == ')') depth--;
else if (c == ',' && depth == 0)
{
yield return args[start..i];
start = i + 1;
}
}
}

yield return args[start..];
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System.Collections.Generic;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Query;

namespace EntityFrameworkCore.Projectables.Infrastructure.Internal
{
/// <summary>
/// Registers the <see cref="SqlExpressionMethodCallTranslator"/> with EF Core's method call
/// translation pipeline so that methods decorated with <see cref="SqlExpressionAttribute"/>
/// are translated to the corresponding SQL expressions.
/// </summary>
public class SqlExpressionMethodCallTranslatorPlugin : IMethodCallTranslatorPlugin
{
public SqlExpressionMethodCallTranslatorPlugin(ISqlExpressionFactory sqlExpressionFactory, ICurrentDbContext currentDbContext)
{
var providerName = currentDbContext.Context.Database.ProviderName;
Translators = new IMethodCallTranslator[]
{
new SqlExpressionMethodCallTranslator(sqlExpressionFactory, providerName)
};
}

/// <inheritdoc />
public IEnumerable<IMethodCallTranslator> Translators { get; }
}
}
59 changes: 59 additions & 0 deletions src/EntityFrameworkCore.Projectables/SqlExpressionAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using System;
using System.Diagnostics.CodeAnalysis;

namespace EntityFrameworkCore.Projectables
{
/// <summary>
/// Decorates a static method with a SQL template string that will be used to translate
/// the method call into a SQL expression when used in a LINQ query against EF Core.
/// Use positional placeholders {0}, {1}, etc. to refer to the method arguments.
/// Multiple instances of this attribute may be applied to the same method, each with a
/// different <see cref="Configuration"/> value, to provide provider-specific SQL expressions.
/// </summary>
/// <example>
/// <code>
/// [SqlExpression("SOUNDEX({0})")]
/// public static string Soundex(string value) => throw new NotImplementedException();
///
/// [SqlExpression("COALESCE({0}, {1})")]
/// public static string Coalesce(string value, string fallback) => throw new NotImplementedException();
///
/// [SqlExpression("STRFTIME('%Y', {0})", Configuration = "Sqlite")]
/// [SqlExpression("YEAR({0})", Configuration = "SqlServer")]
/// [SqlExpression("EXTRACT(YEAR FROM {0})", Configuration = "Npgsql")]
/// public static int Year(DateTime date) => throw new NotImplementedException();
/// </code>
/// </example>
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
public sealed class SqlExpressionAttribute : Attribute
{
/// <summary>
/// Initializes a new instance of <see cref="SqlExpressionAttribute"/> with the given SQL template.
/// </summary>
/// <param name="sql">
/// The SQL template. Use {0}, {1}, etc. as positional placeholders for method arguments.
/// </param>
public SqlExpressionAttribute([StringSyntax("sql")] string sql)
{
Sql = sql;
}

/// <summary>
/// The SQL template string with positional argument placeholders ({0}, {1}, etc.).
/// </summary>
public string Sql { get; }

/// <summary>
/// When <c>true</c> (the default), the method can only be evaluated server-side and must
/// throw <see cref="NotImplementedException"/> in its body.
/// </summary>
public bool ServerSideOnly { get; set; } = true;

/// <summary>
/// When set, this attribute only applies when the database provider name contains this value
/// (e.g. <c>"SqlServer"</c>, <c>"Sqlite"</c>, <c>"Npgsql"</c>).
/// When <c>null</c> (the default), the attribute acts as a fallback for any provider.
/// </summary>
public string? Configuration { get; set; }
}
}
Loading