Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<Project>
<PropertyGroup>
<NoWarn>CS1591;NU5104;CS1573</NoWarn>
<Version>18.1.0</Version>
<Version>19.0.0</Version>
<AssemblyVersion>1.0.0</AssemblyVersion>
<PackageTags>EntityFrameworkCore, EntityFramework, GraphQL</PackageTags>
<SignAssembly>false</SignAssembly>
Expand Down
2 changes: 1 addition & 1 deletion src/GraphQL.EntityFramework/ConnectionConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ static async Task<Connection<TItem>> Range<TSource, TItem>(
var page = list.Skip(skip).Take(take);
QueryLogger.Write(page);
IEnumerable<TItem> result = await page.ToListAsync(cancellation);
result = await filters.ApplyFilter(result, context.UserContext);
result = await filters.ApplyFilter(result, context.UserContext, context.User);

cancellation.ThrowIfCancellationRequested();
return Build(skip, take, count, result);
Expand Down
32 changes: 17 additions & 15 deletions src/GraphQL.EntityFramework/Filters/Filters.cs
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
namespace GraphQL.EntityFramework;
using System.Security.Claims;

namespace GraphQL.EntityFramework;

#region FiltersSignature

public class Filters
{
public delegate bool Filter<in TEntity>(object userContext, TEntity input)
public delegate bool Filter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
where TEntity : class;

public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, TEntity input)
public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
where TEntity : class;

#endregion

public void Add<TEntity>(Filter<TEntity> filter)
where TEntity : class =>
funcs[typeof(TEntity)] =
(context, item) =>
(userContext, userPrincipal, item) =>
{
try
{
return Task.FromResult(filter(context, (TEntity) item));
return Task.FromResult(filter(userContext, userPrincipal, (TEntity) item));
}
catch (Exception exception)
{
Expand All @@ -30,23 +32,23 @@ public void Add<TEntity>(Filter<TEntity> filter)
public void Add<TEntity>(AsyncFilter<TEntity> filter)
where TEntity : class =>
funcs[typeof(TEntity)] =
async (context, item) =>
async (userContext, userPrincipal, item) =>
{
try
{
return await filter(context, (TEntity) item);
return await filter(userContext, userPrincipal, (TEntity) item);
}
catch (Exception exception)
{
throw new($"Failed to execute filter. {nameof(TEntity)}: {typeof(TEntity)}.", exception);
}
};

delegate Task<bool> Filter(object userContext, object input);
delegate Task<bool> Filter(object userContext, ClaimsPrincipal? userPrincipal, object input);

Dictionary<Type, Filter> funcs = new();

internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext)
internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext, ClaimsPrincipal? userPrincipal)
where TEntity : class
{
if (funcs.Count == 0)
Expand All @@ -63,7 +65,7 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
var list = new List<TEntity>();
foreach (var item in result)
{
if (await ShouldInclude(userContext, item, filters))
if (await ShouldInclude(userContext, userPrincipal, item, filters))
{
list.Add(item);
}
Expand All @@ -72,12 +74,12 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
return list;
}

static async Task<bool> ShouldInclude<TEntity>(object userContext, TEntity item, List<AsyncFilter<TEntity>> filters)
static async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity item, List<AsyncFilter<TEntity>> filters)
where TEntity : class
{
foreach (var func in filters)
{
if (!await func(userContext, item))
if (!await func(userContext, userPrincipal, item))
{
return false;
}
Expand All @@ -86,7 +88,7 @@ static async Task<bool> ShouldInclude<TEntity>(object userContext, TEntity item,
return true;
}

internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, TEntity? item)
internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item)
where TEntity : class
{
if (item is null)
Expand All @@ -101,7 +103,7 @@ internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, TEn

foreach (var func in FindFilters<TEntity>())
{
if (!await func(userContext, item))
if (!await func(userContext, userPrincipal, item))
{
return false;
}
Expand All @@ -116,7 +118,7 @@ IEnumerable<AsyncFilter<TEntity>> FindFilters<TEntity>()
var type = typeof(TEntity);
foreach (var pair in funcs.Where(x => x.Key.IsAssignableFrom(type)))
{
yield return (context, item) => pair.Value(context, item);
yield return (userContext, userPrincipal, item) => pair.Value(userContext, userPrincipal, item);
}
}
}
7 changes: 4 additions & 3 deletions src/GraphQL.EntityFramework/Filters/NullFilters.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
using GraphQL.EntityFramework;
using System.Security.Claims;
using GraphQL.EntityFramework;

class NullFilters :
Filters
{
public static NullFilters Instance = new();

internal override Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext) =>
internal override Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext, ClaimsPrincipal? userPrincipal) =>
Task.FromResult(result);

internal override Task<bool> ShouldInclude<TEntity>(object userContext, TEntity? item)
internal override Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item)
where TEntity : class =>
Task.FromResult(true);
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public FieldBuilder<TSource, TReturn> AddNavigationField<TSource, TReturn>(
var fieldContext = BuildContext(context);

var result = resolve(fieldContext);
if (await fieldContext.Filters.ShouldInclude(context.UserContext, result))
if (await fieldContext.Filters.ShouldInclude(context.UserContext, context.User, result))
{
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ ConnectionBuilder<TSource> AddEnumerableConnection<TSource, TGraph, TReturn>(
var enumerable = resolve(efFieldContext);

enumerable = enumerable.ApplyGraphQlArguments(hasId, context);
enumerable = await efFieldContext.Filters.ApplyFilter(enumerable, context.UserContext);
enumerable = await efFieldContext.Filters.ApplyFilter(enumerable, context.UserContext, context.User);
var page = enumerable.ToList();

return ConnectionConverter.ApplyConnectionContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public FieldBuilder<TSource, TReturn> AddNavigationListField<TSource, TReturn>(
var fieldContext = BuildContext(context);
var result = resolve(fieldContext);
result = result.ApplyGraphQlArguments(hasId, context);
return await fieldContext.Filters.ApplyFilter(result, context.UserContext);
return await fieldContext.Filters.ApplyFilter(result, context.UserContext, context.User);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ FieldType BuildQueryField<TSource, TReturn>(
.ToListAsync(context.CancellationToken);
}

return await fieldContext.Filters.ApplyFilter(list, context.UserContext);
return await fieldContext.Filters.ApplyFilter(list, context.UserContext, context.User);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ FieldType BuildSingleField<TSource, TReturn>(

if (single is not null)
{
if (await efFieldContext.Filters.ShouldInclude(context.UserContext, single))
if (await efFieldContext.Filters.ShouldInclude(context.UserContext, context.User, single))
{
if (mutate is not null)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Snippets/GlobalFilterSnippets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public void Add(ServiceCollection services)

var filters = new Filters();
filters.Add<MyEntity>(
(userContext, item) => item.Property != "Ignore");
(userContext, userPrincipal, item) => item.Property != "Ignore");
EfGraphQLConventions.RegisterInContainer<MyDbContext>(
services,
resolveFilters: x => filters);
Expand Down
34 changes: 17 additions & 17 deletions src/Tests/GlobalFiltersTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ public class GlobalFiltersTests
public async Task Simple()
{
var filters= new Filters();
filters.Add<Target>((_, target) => target.Property != "Ignore");
Assert.True(await filters.ShouldInclude(new(), new Target()));
Assert.False(await filters.ShouldInclude<object>(new(), null));
Assert.True(await filters.ShouldInclude(new(), new Target {Property = "Include"}));
Assert.False(await filters.ShouldInclude(new(), new Target {Property = "Ignore"}));

filters.Add<BaseTarget>((_, target) => target.Property != "Ignore");
Assert.True(await filters.ShouldInclude(new(), new ChildTarget()));
Assert.True(await filters.ShouldInclude(new(), new ChildTarget {Property = "Include"}));
Assert.False(await filters.ShouldInclude(new(), new ChildTarget {Property = "Ignore"}));

filters.Add<ITarget>((_, target) => target.Property != "Ignore");
Assert.True(await filters.ShouldInclude(new(), new ImplementationTarget()));
Assert.True(await filters.ShouldInclude(new(), new ImplementationTarget { Property = "Include"}));
Assert.False(await filters.ShouldInclude(new(), new ImplementationTarget { Property = "Ignore" }));

Assert.True(await filters.ShouldInclude(new(), new NonTarget { Property = "Foo" }));
filters.Add<Target>((_, _, target) => target.Property != "Ignore");
Assert.True(await filters.ShouldInclude(new(), null, new Target()));
Assert.False(await filters.ShouldInclude<object>(new(), null, null));
Assert.True(await filters.ShouldInclude(new(), null, new Target {Property = "Include"}));
Assert.False(await filters.ShouldInclude(new(), null, new Target {Property = "Ignore"}));

filters.Add<BaseTarget>((_, _, target) => target.Property != "Ignore");
Assert.True(await filters.ShouldInclude(new(), null, new ChildTarget()));
Assert.True(await filters.ShouldInclude(new(), null, new ChildTarget {Property = "Include"}));
Assert.False(await filters.ShouldInclude(new(), null, new ChildTarget {Property = "Ignore"}));

filters.Add<ITarget>((_, _, target) => target.Property != "Ignore");
Assert.True(await filters.ShouldInclude(new(), null, new ImplementationTarget()));
Assert.True(await filters.ShouldInclude(new(), null, new ImplementationTarget { Property = "Include"}));
Assert.False(await filters.ShouldInclude(new(), null, new ImplementationTarget { Property = "Ignore" }));

Assert.True(await filters.ShouldInclude(new(), null, new NonTarget { Property = "Foo" }));
}

public class NonTarget
Expand Down
4 changes: 2 additions & 2 deletions src/Tests/IntegrationTests/IntegrationTests_filtered.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ public async Task Child_filtered()
static Filters BuildFilters()
{
var filters = new Filters();
filters.Add<FilterParentEntity>((_, item) => item.Property != "Ignore");
filters.Add<FilterChildEntity>((_, item) => item.Property != "Ignore");
filters.Add<FilterParentEntity>((_, _, item) => item.Property != "Ignore");
filters.Add<FilterChildEntity>((_, _, item) => item.Property != "Ignore");
return filters;
}

Expand Down