From 6c9debcba97afe37a9c146b5957d048c4fb851ff Mon Sep 17 00:00:00 2001 From: Simon Date: Tue, 23 Aug 2022 13:20:40 +1000 Subject: [PATCH] . --- src/Directory.Build.props | 2 +- .../ConnectionConverter.cs | 2 +- .../Filters/Filters.cs | 32 +++++++++-------- .../Filters/NullFilters.cs | 7 ++-- .../GraphApi/EfGraphQLService_Navigation.cs | 2 +- .../EfGraphQLService_NavigationConnection.cs | 2 +- .../EfGraphQLService_NavigationList.cs | 2 +- .../GraphApi/EfGraphQLService_Queryable.cs | 2 +- .../GraphApi/EfGraphQLService_Single.cs | 2 +- src/Snippets/GlobalFilterSnippets.cs | 2 +- src/Tests/GlobalFiltersTests.cs | 34 +++++++++---------- .../IntegrationTests_filtered.cs | 4 +-- 12 files changed, 48 insertions(+), 45 deletions(-) diff --git a/src/Directory.Build.props b/src/Directory.Build.props index d696f972c..f97d389f0 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -2,7 +2,7 @@ CS1591;NU5104;CS1573 - 18.1.0 + 19.0.0 1.0.0 EntityFrameworkCore, EntityFramework, GraphQL false diff --git a/src/GraphQL.EntityFramework/ConnectionConverter.cs b/src/GraphQL.EntityFramework/ConnectionConverter.cs index d3af04b1c..5e6eed0cc 100644 --- a/src/GraphQL.EntityFramework/ConnectionConverter.cs +++ b/src/GraphQL.EntityFramework/ConnectionConverter.cs @@ -172,7 +172,7 @@ static async Task> Range( var page = list.Skip(skip).Take(take); QueryLogger.Write(page); IEnumerable result = await page.ToListAsync(cancellation); - result = await filters.ApplyFilter(result, context.UserContext); + result = await filters.ApplyFilter(result, context.UserContext, context.User); cancellation.ThrowIfCancellationRequested(); return Build(skip, take, count, result); diff --git a/src/GraphQL.EntityFramework/Filters/Filters.cs b/src/GraphQL.EntityFramework/Filters/Filters.cs index d775b7766..bb7e6187a 100644 --- a/src/GraphQL.EntityFramework/Filters/Filters.cs +++ b/src/GraphQL.EntityFramework/Filters/Filters.cs @@ -1,13 +1,15 @@ -namespace GraphQL.EntityFramework; +using System.Security.Claims; + +namespace GraphQL.EntityFramework; #region FiltersSignature public class Filters { - public delegate bool Filter(object userContext, TEntity input) + public delegate bool Filter(object userContext, ClaimsPrincipal? userPrincipal, TEntity input) where TEntity : class; - public delegate Task AsyncFilter(object userContext, TEntity input) + public delegate Task AsyncFilter(object userContext, ClaimsPrincipal? userPrincipal, TEntity input) where TEntity : class; #endregion @@ -15,11 +17,11 @@ public delegate Task AsyncFilter(object userContext, TEntity i public void Add(Filter filter) where TEntity : class => funcs[typeof(TEntity)] = - (context, item) => + (userContext, userPrincipal, item) => { try { - return Task.FromResult(filter(context, (TEntity) item)); + return Task.FromResult(filter(userContext, userPrincipal, (TEntity) item)); } catch (Exception exception) { @@ -30,11 +32,11 @@ public void Add(Filter filter) public void Add(AsyncFilter filter) where TEntity : class => funcs[typeof(TEntity)] = - async (context, item) => + async (userContext, userPrincipal, item) => { try { - return await filter(context, (TEntity) item); + return await filter(userContext, userPrincipal, (TEntity) item); } catch (Exception exception) { @@ -42,11 +44,11 @@ public void Add(AsyncFilter filter) } }; - delegate Task Filter(object userContext, object input); + delegate Task Filter(object userContext, ClaimsPrincipal? userPrincipal, object input); Dictionary funcs = new(); - internal virtual async Task> ApplyFilter(IEnumerable result, object userContext) + internal virtual async Task> ApplyFilter(IEnumerable result, object userContext, ClaimsPrincipal? userPrincipal) where TEntity : class { if (funcs.Count == 0) @@ -63,7 +65,7 @@ internal virtual async Task> ApplyFilter(IEnumerab var list = new List(); foreach (var item in result) { - if (await ShouldInclude(userContext, item, filters)) + if (await ShouldInclude(userContext, userPrincipal, item, filters)) { list.Add(item); } @@ -72,12 +74,12 @@ internal virtual async Task> ApplyFilter(IEnumerab return list; } - static async Task ShouldInclude(object userContext, TEntity item, List> filters) + static async Task ShouldInclude(object userContext, ClaimsPrincipal? userPrincipal, TEntity item, List> filters) where TEntity : class { foreach (var func in filters) { - if (!await func(userContext, item)) + if (!await func(userContext, userPrincipal, item)) { return false; } @@ -86,7 +88,7 @@ static async Task ShouldInclude(object userContext, TEntity item, return true; } - internal virtual async Task ShouldInclude(object userContext, TEntity? item) + internal virtual async Task ShouldInclude(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item) where TEntity : class { if (item is null) @@ -101,7 +103,7 @@ internal virtual async Task ShouldInclude(object userContext, TEn foreach (var func in FindFilters()) { - if (!await func(userContext, item)) + if (!await func(userContext, userPrincipal, item)) { return false; } @@ -116,7 +118,7 @@ IEnumerable> FindFilters() var type = typeof(TEntity); foreach (var pair in funcs.Where(x => x.Key.IsAssignableFrom(type))) { - yield return (context, item) => pair.Value(context, item); + yield return (userContext, userPrincipal, item) => pair.Value(userContext, userPrincipal, item); } } } \ No newline at end of file diff --git a/src/GraphQL.EntityFramework/Filters/NullFilters.cs b/src/GraphQL.EntityFramework/Filters/NullFilters.cs index 1f5365fe8..c2928ec59 100644 --- a/src/GraphQL.EntityFramework/Filters/NullFilters.cs +++ b/src/GraphQL.EntityFramework/Filters/NullFilters.cs @@ -1,14 +1,15 @@ -using GraphQL.EntityFramework; +using System.Security.Claims; +using GraphQL.EntityFramework; class NullFilters : Filters { public static NullFilters Instance = new(); - internal override Task> ApplyFilter(IEnumerable result, object userContext) => + internal override Task> ApplyFilter(IEnumerable result, object userContext, ClaimsPrincipal? userPrincipal) => Task.FromResult(result); - internal override Task ShouldInclude(object userContext, TEntity? item) + internal override Task ShouldInclude(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item) where TEntity : class => Task.FromResult(true); } \ No newline at end of file diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs index ea7baf67d..f907a826c 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs @@ -35,7 +35,7 @@ public FieldBuilder AddNavigationField( var fieldContext = BuildContext(context); var result = resolve(fieldContext); - if (await fieldContext.Filters.ShouldInclude(context.UserContext, result)) + if (await fieldContext.Filters.ShouldInclude(context.UserContext, context.User, result)) { return result; } diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationConnection.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationConnection.cs index 55794708c..5bea293d9 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationConnection.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationConnection.cs @@ -49,7 +49,7 @@ ConnectionBuilder AddEnumerableConnection( var enumerable = resolve(efFieldContext); enumerable = enumerable.ApplyGraphQlArguments(hasId, context); - enumerable = await efFieldContext.Filters.ApplyFilter(enumerable, context.UserContext); + enumerable = await efFieldContext.Filters.ApplyFilter(enumerable, context.UserContext, context.User); var page = enumerable.ToList(); return ConnectionConverter.ApplyConnectionContext( diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationList.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationList.cs index ec4b0ad54..307cb4fcb 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationList.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationList.cs @@ -35,7 +35,7 @@ public FieldBuilder AddNavigationListField( var fieldContext = BuildContext(context); var result = resolve(fieldContext); result = result.ApplyGraphQlArguments(hasId, context); - return await fieldContext.Filters.ApplyFilter(result, context.UserContext); + return await fieldContext.Filters.ApplyFilter(result, context.UserContext, context.User); }); } diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Queryable.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Queryable.cs index 4837146e8..d1b9845ca 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Queryable.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Queryable.cs @@ -77,7 +77,7 @@ FieldType BuildQueryField( .ToListAsync(context.CancellationToken); } - return await fieldContext.Filters.ApplyFilter(list, context.UserContext); + return await fieldContext.Filters.ApplyFilter(list, context.UserContext, context.User); }); } diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Single.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Single.cs index 075e915ee..fcb7c6444 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Single.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Single.cs @@ -98,7 +98,7 @@ FieldType BuildSingleField( if (single is not null) { - if (await efFieldContext.Filters.ShouldInclude(context.UserContext, single)) + if (await efFieldContext.Filters.ShouldInclude(context.UserContext, context.User, single)) { if (mutate is not null) { diff --git a/src/Snippets/GlobalFilterSnippets.cs b/src/Snippets/GlobalFilterSnippets.cs index 51478e9ea..ba0fab60d 100644 --- a/src/Snippets/GlobalFilterSnippets.cs +++ b/src/Snippets/GlobalFilterSnippets.cs @@ -20,7 +20,7 @@ public void Add(ServiceCollection services) var filters = new Filters(); filters.Add( - (userContext, item) => item.Property != "Ignore"); + (userContext, userPrincipal, item) => item.Property != "Ignore"); EfGraphQLConventions.RegisterInContainer( services, resolveFilters: x => filters); diff --git a/src/Tests/GlobalFiltersTests.cs b/src/Tests/GlobalFiltersTests.cs index 2a2cdf7d2..f78703c90 100644 --- a/src/Tests/GlobalFiltersTests.cs +++ b/src/Tests/GlobalFiltersTests.cs @@ -6,23 +6,23 @@ public class GlobalFiltersTests public async Task Simple() { var filters= new Filters(); - filters.Add((_, target) => target.Property != "Ignore"); - Assert.True(await filters.ShouldInclude(new(), new Target())); - Assert.False(await filters.ShouldInclude(new(), null)); - Assert.True(await filters.ShouldInclude(new(), new Target {Property = "Include"})); - Assert.False(await filters.ShouldInclude(new(), new Target {Property = "Ignore"})); - - filters.Add((_, target) => target.Property != "Ignore"); - Assert.True(await filters.ShouldInclude(new(), new ChildTarget())); - Assert.True(await filters.ShouldInclude(new(), new ChildTarget {Property = "Include"})); - Assert.False(await filters.ShouldInclude(new(), new ChildTarget {Property = "Ignore"})); - - filters.Add((_, target) => target.Property != "Ignore"); - Assert.True(await filters.ShouldInclude(new(), new ImplementationTarget())); - Assert.True(await filters.ShouldInclude(new(), new ImplementationTarget { Property = "Include"})); - Assert.False(await filters.ShouldInclude(new(), new ImplementationTarget { Property = "Ignore" })); - - Assert.True(await filters.ShouldInclude(new(), new NonTarget { Property = "Foo" })); + filters.Add((_, _, target) => target.Property != "Ignore"); + Assert.True(await filters.ShouldInclude(new(), null, new Target())); + Assert.False(await filters.ShouldInclude(new(), null, null)); + Assert.True(await filters.ShouldInclude(new(), null, new Target {Property = "Include"})); + Assert.False(await filters.ShouldInclude(new(), null, new Target {Property = "Ignore"})); + + filters.Add((_, _, target) => target.Property != "Ignore"); + Assert.True(await filters.ShouldInclude(new(), null, new ChildTarget())); + Assert.True(await filters.ShouldInclude(new(), null, new ChildTarget {Property = "Include"})); + Assert.False(await filters.ShouldInclude(new(), null, new ChildTarget {Property = "Ignore"})); + + filters.Add((_, _, target) => target.Property != "Ignore"); + Assert.True(await filters.ShouldInclude(new(), null, new ImplementationTarget())); + Assert.True(await filters.ShouldInclude(new(), null, new ImplementationTarget { Property = "Include"})); + Assert.False(await filters.ShouldInclude(new(), null, new ImplementationTarget { Property = "Ignore" })); + + Assert.True(await filters.ShouldInclude(new(), null, new NonTarget { Property = "Foo" })); } public class NonTarget diff --git a/src/Tests/IntegrationTests/IntegrationTests_filtered.cs b/src/Tests/IntegrationTests/IntegrationTests_filtered.cs index a46ca46c8..651b76212 100644 --- a/src/Tests/IntegrationTests/IntegrationTests_filtered.cs +++ b/src/Tests/IntegrationTests/IntegrationTests_filtered.cs @@ -40,8 +40,8 @@ public async Task Child_filtered() static Filters BuildFilters() { var filters = new Filters(); - filters.Add((_, item) => item.Property != "Ignore"); - filters.Add((_, item) => item.Property != "Ignore"); + filters.Add((_, _, item) => item.Property != "Ignore"); + filters.Add((_, _, item) => item.Property != "Ignore"); return filters; }