Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Setup .NET
uses: actions/setup-dotnet@v1
with:
dotnet-version: 6.0.x
dotnet-version: 7.0.x
- name: Restore dependencies
run: dotnet restore
- name: Build
Expand All @@ -46,7 +46,7 @@ jobs:
- name: Setup .NET
uses: actions/setup-dotnet@v1
with:
dotnet-version: 6.0.x
dotnet-version: 7.0.x
- name: Pack
run: |
dotnet pack -v normal -c Debug --include-symbols --include-source -p:PackageVersion=2.0.0-pre-$GITHUB_RUN_ID -o nupkg
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Setup .NET Core
uses: actions/setup-dotnet@v1
with:
dotnet-version: 6.0.x
dotnet-version: 7.0.x
include-prerelease: True
- name: Create Release NuGet package
run: |
Expand Down
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<LangVersion>9.0</LangVersion>
<LangVersion>11.0</LangVersion>
<Nullable>enable</Nullable>
<EnableNETAnalyzers>true</EnableNETAnalyzers>
<NoWarn>CS1591</NoWarn>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<IsPackable>false</IsPackable>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="BenchmarkDotNet" Version="0.13.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="6.0.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="7.0.0" />
</ItemGroup>

<ItemGroup>
Expand Down
88 changes: 88 additions & 0 deletions src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Reflection.Metadata;
using System.Runtime.CompilerServices;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading.Tasks;

Expand All @@ -21,5 +24,90 @@ public static IEnumerable<Type> GetNestedTypePath(this Type type)

yield return type;
}

public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
{
// We only need to search for virtual instance methods who are not declared on the derivedType
if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual)
{
return methodInfo;
}

if (!derivedType.IsAssignableTo(methodInfo.DeclaringType))
{
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
}

var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);

foreach (var derivedMethodInfo in derivedMethods)
{
if (HasCompatibleSignature(methodInfo, derivedMethodInfo))
{
return derivedMethodInfo;
}
}

// No derived methods were found. Return the original methodInfo
return methodInfo;

static bool HasCompatibleSignature(MethodInfo methodInfo, MethodInfo derivedMethodInfo)
{
if (methodInfo.Name != derivedMethodInfo.Name)
{
return false;
}

var methodParameters = methodInfo.GetParameters();

var derivedMethodParameters = derivedMethodInfo.GetParameters();
if (methodParameters.Length != derivedMethodParameters.Length)
{
return false;
}

// Match all parameters
for (var parameterIndex = 0; parameterIndex < methodParameters.Length; parameterIndex++)
{
var parameter = methodParameters[parameterIndex];
var derivedParameter = derivedMethodParameters[parameterIndex];

if (parameter.ParameterType.IsGenericParameter)
{
if (!derivedParameter.ParameterType.IsGenericParameter)
{
return false;
}
}
else
{
if (parameter.ParameterType != derivedParameter.ParameterType)
{
return false;
}
}
}

// Match the number of generic type arguments
if (methodInfo.IsGenericMethodDefinition)
{
var methodGenericParameters = methodInfo.GetGenericArguments();

if (!derivedMethodInfo.IsGenericMethodDefinition)
{
return false;
}

var derivedGenericArguments = derivedMethodInfo.GetGenericArguments();

if (methodGenericParameters.Length != derivedGenericArguments.Length)
{
return false;
}
}

return true;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
Expand All @@ -7,6 +8,7 @@
using System.Text;
using System.Threading.Tasks;
using System.Xml.Linq;
using EntityFrameworkCore.Projectables.Extensions;

namespace EntityFrameworkCore.Projectables.Services
{
Expand Down Expand Up @@ -39,7 +41,10 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La

protected override Expression VisitMethodCall(MethodCallExpression node)
{
if (TryGetReflectedExpression(node.Method, out var reflectedExpression))
// Get the overriding methodInfo based on te type of the received of this expression
var methodInfo = node.Object?.Type.GetOverridingMethod(node.Method) ?? node.Method;

if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
{
for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++)
{
Expand Down Expand Up @@ -69,7 +74,12 @@ protected override Expression VisitMethodCall(MethodCallExpression node)

protected override Expression VisitMember(MemberExpression node)
{
if (TryGetReflectedExpression(node.Member, out var reflectedExpression))
var nodeMember = node.Expression switch {
{ Type: { } } => node.Expression.Type.GetMember(node.Member.Name, node.Member.MemberType, BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)[0],
_ => node.Member
};

if (TryGetReflectedExpression(nodeMember, out var reflectedExpression))
{
if (node.Expression is not null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver
{
public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo)
{
var reflectedType = projectableMemberInfo.ReflectedType ?? throw new InvalidOperationException("Expected a valid type here");
var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(reflectedType.Namespace, reflectedType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name);
var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here");
var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name);

var genericArguments = projectableMemberInfo switch {
MethodInfo methodInfo => methodInfo.GetGenericArguments(),
_ => null
};

var expressionFactoryMethod = reflectedType.Assembly.GetType(generatedContainingTypeName)
var expressionFactoryMethod = declaringType.Assembly.GetType(generatedContainingTypeName)
?.GetMethods()
?.FirstOrDefault();

Expand All @@ -40,20 +40,20 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo

if (useMemberBody is not null)
{
var exprProperty = reflectedType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
var exprProperty = declaringType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
var lambda = exprProperty?.GetValue(null) as LambdaExpression;

if (lambda is not null)
{
if (projectableMemberInfo is PropertyInfo property &&
lambda.Parameters.Count == 1 &&
lambda.Parameters[0].Type == reflectedType && lambda.ReturnType == property.PropertyType)
lambda.Parameters[0].Type == declaringType && lambda.ReturnType == property.PropertyType)
{
return lambda;
}
else if (projectableMemberInfo is MethodInfo method &&
lambda.Parameters.Count == method.GetParameters().Length + 1 &&
lambda.Parameters.Last().Type == reflectedType &&
lambda.Parameters.Last().Type == declaringType &&
!lambda.Parameters.Zip(method.GetParameters(), (a, b) => a.Type != b.ParameterType).Any())
{
return lambda;
Expand All @@ -62,8 +62,8 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo
}

var fullName = string.Join(".", Enumerable.Empty<string>()
.Concat(new[] { reflectedType.Namespace })
.Concat(reflectedType.GetNestedTypePath().Select(x => x.Name))
.Concat(new[] { declaringType.Namespace })
.Concat(declaringType.GetNestedTypePath().Select(x => x.Name))
.Concat(new[] { projectableMemberInfo.Name }));

throw new InvalidOperationException($"Unable to resolve generated expression for {fullName}.") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<IsPackable>false</IsPackable>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<CompilerGeneratedFilesOutputPath>$(BaseIntermediateOutputPath)Generated</CompilerGeneratedFilesOutputPath>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT 1 + 1
FROM [Concrete] AS [c]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT 1 + 1
FROM [Concrete] AS [c]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT 1 + 1
FROM [Concrete] AS [c]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT 1 + 1
FROM [Concrete] AS [c]
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading.Tasks;
using EntityFrameworkCore.Projectables.FunctionalTests.Helpers;
using EntityFrameworkCore.Projectables.Services;
using Microsoft.EntityFrameworkCore;
using ScenarioTests;
using VerifyXunit;
using Xunit;

#nullable disable

namespace EntityFrameworkCore.Projectables.FunctionalTests
{

[UsesVerify]
public class InheritedModelTests
{
public abstract class Base
{
public int Id { get; set; }

[Projectable]
public int ComputedProperty => SampleProperty + 1;

public virtual int SampleProperty => 0;

[Projectable]
public int ComputedMethod() => SampleMethod() + 1;

public virtual int SampleMethod() => 0;
}

public class Concrete : Base
{
[Projectable]
public override int SampleProperty => 1;

[Projectable]
public override int SampleMethod() => 1;
}

public class MoreConcrete : Concrete
{
}

[Fact]
public Task ProjectOverOverriddenPropertyImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();

var query = dbContext.Set<Concrete>()
.Select(x => x.ComputedProperty);

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task ProjectOverInheritedPropertyImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();

var query = dbContext.Set<Concrete>()
.Select(x => x.ComputedProperty);

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task ProjectOverOverriddenMethodImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();

var query = dbContext.Set<Concrete>()
.Select(x => x.ComputedMethod());

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task ProjectOverInheritedMethodImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();

var query = dbContext.Set<Concrete>()
.Select(x => x.ComputedMethod());

return Verifier.Verify(query.ToQueryString());
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<IsPackable>false</IsPackable>
</PropertyGroup>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<IsPackable>false</IsPackable>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>

Expand Down
Loading