Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 25 additions & 33 deletions src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,8 @@ private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo me
return true;
}

private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods)
{
if (allDerivedMethods is { Length: > 0 })
{
var baseDefinition = methodInfo.GetBaseDefinition();
for (var i = 0; i < allDerivedMethods.Length; i++)
{
var derivedMethodInfo = allDerivedMethods[i];
if (derivedMethodInfo.GetBaseDefinition() == baseDefinition)
{
return i;
}
}
}

return null;
}
private static bool IsOverridingMethodOf(this MethodInfo methodInfo, MethodInfo baseDefinition)
=> methodInfo.GetBaseDefinition() == baseDefinition;

public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
{
Expand All @@ -81,31 +66,38 @@ public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo m

var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);

return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i
? derivedMethods[i]
// No derived methods were found. Return the original methodInfo
: methodInfo;
MethodInfo? overridingMethod = null;
if (derivedMethods is { Length: > 0 })
{
var baseDefinition = methodInfo.GetBaseDefinition();
overridingMethod = derivedMethods.FirstOrDefault(derivedMethodInfo
=> derivedMethodInfo.IsOverridingMethodOf(baseDefinition));
}

return overridingMethod ?? methodInfo; // If no derived methods were found, return the original methodInfo
}

public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo)
{
var accessor = propertyInfo.GetAccessors(true)[0];

if (!derivedType.CanHaveOverridingMethod(accessor))
var accessor = propertyInfo.GetAccessors(true).FirstOrDefault(derivedType.CanHaveOverridingMethod);
if (accessor is null)
{
return propertyInfo;
}

var isGetAccessor = propertyInfo.GetMethod == accessor;

var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
var derivedPropertyMethods = derivedProperties
.Select((Func<PropertyInfo, MethodInfo?>)
(propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod))
.OfType<MethodInfo>().ToArray();

return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i
? derivedProperties[i]
// No derived methods were found. Return the original methodInfo
: propertyInfo;

PropertyInfo? overridingProperty = null;
if (derivedProperties is { Length: > 0 })
{
var baseDefinition = accessor.GetBaseDefinition();
overridingProperty = derivedProperties.FirstOrDefault(p
=> (isGetAccessor ? p.GetMethod : p.SetMethod)?.IsOverridingMethodOf(baseDefinition) == true);
}

return overridingProperty ?? propertyInfo; // If no derived methods were found, return the original methodInfo
}

public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor
var targetDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryCompiler));
if (targetDescriptor is null)
{
throw new InvalidOperationException("No QueryProvider is configured yet. Please make sure to configure a database provider first"); ;
throw new InvalidOperationException("No QueryProvider is configured yet. Please make sure to configure a database provider first");
}

var decoratorObjectFactory = ActivatorUtilities.CreateFactory(typeof(CustomQueryCompiler), new[] { targetDescriptor.ServiceType });
Expand All @@ -70,7 +70,7 @@ static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor
var targetDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryTranslationPreprocessorFactory));
if (targetDescriptor is null)
{
throw new InvalidOperationException("No QueryTranslationPreprocessorFactory is configured yet. Please make sure to configure a database provider first"); ;
throw new InvalidOperationException("No QueryTranslationPreprocessorFactory is configured yet. Please make sure to configure a database provider first");
}

var decoratorObjectFactory = ActivatorUtilities.CreateFactory(typeof(CustomQueryTranslationPreprocessorFactory), new[] { targetDescriptor.ServiceType });
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT 4
FROM [Concrete] AS [c]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT [c].[Id]
FROM [BaseProvider] AS [b]
INNER JOIN [Concrete] AS [c] ON [b].[Id] = [c].[BaseProviderId]
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,20 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
[UsesVerify]
public class InheritedModelTests
{
public interface IBaseProvider<TBase>
{
ICollection<TBase> Bases { get; set; }
}

public class BaseProvider : IBaseProvider<Concrete>
{
public int Id { get; set; }
public ICollection<Concrete> Bases { get; set; }
}

public interface IBase
{
int Id { get; }
int ComputedProperty { get; }
int ComputedMethod();
}
Expand Down Expand Up @@ -117,6 +129,26 @@ public Task ProjectOverImplementedMethod()

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task ProjectOverProvider()
{
using var dbContext = new SampleDbContext<BaseProvider>();

var query = dbContext.Set<BaseProvider>().AllBases<BaseProvider, Concrete>();

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task ProjectOverExtensionMethod()
{
using var dbContext = new SampleDbContext<Concrete>();

var query = dbContext.Set<Concrete>().Select(c => c.ComputedPropertyPlusMethod());

return Verifier.Verify(query.ToQueryString());
}
}

public static class ModelExtensions
Expand All @@ -128,5 +160,15 @@ public static IQueryable<int> SelectComputedProperty<TConcrete>(this IQueryable<
public static IQueryable<int> SelectComputedMethod<TConcrete>(this IQueryable<TConcrete> concretes)
where TConcrete : InheritedModelTests.IBase
=> concretes.Select(x => x.ComputedMethod());

public static IQueryable<int> AllBases<TProvider, TBase>(this IQueryable<TProvider> concretes)
where TProvider : InheritedModelTests.IBaseProvider<TBase>
where TBase : InheritedModelTests.IBase
=> concretes.SelectMany(x => x.Bases).Select(x => x.Id);

[Projectable]
public static int ComputedPropertyPlusMethod<TConcrete>(this TConcrete concrete)
where TConcrete : InheritedModelTests.IBase
=> concrete.ComputedProperty + concrete.ComputedMethod();
}
}