diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 62f87635face..4622220c2087 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -39,6 +39,7 @@ public static partial class RequestDelegateFactory private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo JsonResultWriteResponseAsyncMethod = GetMethodInfo>((response, value) => HttpResponseJsonExtensions.WriteAsJsonAsync(response, value, default)); + private static readonly MethodInfo StringIsNullOrEmptyMethod = typeof(string).GetMethod(nameof(string.IsNullOrEmpty), BindingFlags.Static | BindingFlags.Public)!; private static readonly MethodInfo LogParameterBindingFailedMethod = GetMethodInfo>((httpContext, parameterType, parameterName, sourceValue, shouldThrow) => Log.ParameterBindingFailed(httpContext, parameterType, parameterName, sourceValue, shouldThrow)); @@ -71,6 +72,8 @@ public static partial class RequestDelegateFactory private static readonly ParameterExpression TempSourceStringExpr = ParameterBindingMethodCache.TempSourceStringExpr; private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null)); private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null)); + private static readonly UnaryExpression TempSourceStringIsNotNullOrEmptyExpr = Expression.Not(Expression.Call(StringIsNullOrEmptyMethod, TempSourceStringExpr)); + private static readonly string[] DefaultAcceptsContentType = new[] { "application/json" }; private static readonly string[] FormFileContentType = new[] { "multipart/form-data" }; @@ -202,6 +205,7 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory var errorMessage = BuildErrorMessageForInferredBodyParameter(factoryContext); throw new InvalidOperationException(errorMessage); } + if (factoryContext.JsonRequestBodyParameter is not null && factoryContext.FirstFormRequestBodyParameter is not null) { @@ -317,7 +321,7 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext { return BindParameterFromBindAsync(parameter, factoryContext); } - else if (parameter.ParameterType == typeof(string) || ParameterBindingMethodCache.HasTryParseMethod(parameter)) + else if (parameter.ParameterType == typeof(string) || ParameterBindingMethodCache.HasTryParseMethod(parameter.ParameterType)) { // 1. We bind from route values only, if route parameters are non-null and the parameter name is in that set. // 2. We bind from query only, if route parameters are non-null and the parameter name is NOT in that set. @@ -342,6 +346,16 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.RouteOrQueryStringParameter); return BindParameterFromRouteValueOrQueryString(parameter, parameter.Name, factoryContext); } + else if (factoryContext.DisableInferredFromBody && ( + (parameter.ParameterType.IsArray && ParameterBindingMethodCache.HasTryParseMethod(parameter.ParameterType.GetElementType()!)) || + parameter.ParameterType == typeof(string[]) || + parameter.ParameterType == typeof(StringValues))) + { + // We only infer parameter types if you have an array of TryParsables/string[]/StringValues, and DisableInferredFromBody is true + + factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.QueryStringParameter); + return BindParameterFromProperty(parameter, QueryExpr, parameter.Name, factoryContext, "query string"); + } else { if (factoryContext.ServiceProviderIsService is IServiceProviderIsService serviceProviderIsService) @@ -884,22 +898,24 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres var parameterNameConstant = Expression.Constant(parameter.Name); var sourceConstant = Expression.Constant(source); - if (parameter.ParameterType == typeof(string)) + if (parameter.ParameterType == typeof(string) || parameter.ParameterType == typeof(string[]) || parameter.ParameterType == typeof(StringValues)) { return BindParameterFromExpression(parameter, valueExpression, factoryContext, source); } factoryContext.UsingTempSourceString = true; - var underlyingNullableType = Nullable.GetUnderlyingType(parameter.ParameterType); + var targetParseType = parameter.ParameterType.IsArray ? parameter.ParameterType.GetElementType()! : parameter.ParameterType; + + var underlyingNullableType = Nullable.GetUnderlyingType(targetParseType); var isNotNullable = underlyingNullableType is null; - var nonNullableParameterType = underlyingNullableType ?? parameter.ParameterType; + var nonNullableParameterType = underlyingNullableType ?? targetParseType; var tryParseMethodCall = ParameterBindingMethodCache.FindTryParseMethod(nonNullableParameterType); if (tryParseMethodCall is null) { - var typeName = TypeNameHelper.GetTypeDisplayName(parameter.ParameterType, fullName: false); + var typeName = TypeNameHelper.GetTypeDisplayName(targetParseType, fullName: false); throw new InvalidOperationException($"No public static bool {typeName}.TryParse(string, out {typeName}) method found for {parameter.Name}."); } @@ -940,8 +956,32 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres // param2_local = 42; // } + // string[]? values = httpContext.Request.Query["param1"].ToArray(); + // int[] param_local = values.Length > 0 ? new int[values.Length] : Array.Empty(); + + // if (values != null) + // { + // int index = 0; + // while (index < values.Length) + // { + // tempSourceString = values[i]; + // if (int.TryParse(tempSourceString, out var parsedValue)) + // { + // param_local[i] = parsedValue; + // } + // else + // { + // wasParamCheckFailure = true; + // Log.ParameterBindingFailed(httpContext, "Int32[]", "param1", tempSourceString); + // break; + // } + // + // index++ + // } + // } + // If the parameter is nullable, create a "parsedValue" local to TryParse into since we cannot use the parameter directly. - var parsedValue = isNotNullable ? argument : Expression.Variable(nonNullableParameterType, "parsedValue"); + var parsedValue = Expression.Variable(nonNullableParameterType, "parsedValue"); var failBlock = Expression.Block( Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), @@ -970,33 +1010,104 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres ) ); + var index = Expression.Variable(typeof(int), "index"); + // If the parameter is nullable, we need to assign the "parsedValue" local to the nullable parameter on success. - Expression tryParseExpression = isNotNullable ? - Expression.IfThen(Expression.Not(tryParseCall), failBlock) : - Expression.Block(new[] { parsedValue }, + var tryParseExpression = Expression.Block(new[] { parsedValue }, Expression.IfThenElse(tryParseCall, - Expression.Assign(argument, Expression.Convert(parsedValue, parameter.ParameterType)), + Expression.Assign(parameter.ParameterType.IsArray ? Expression.ArrayAccess(argument, index) : argument, Expression.Convert(parsedValue, targetParseType)), failBlock)); - var ifNotNullTryParse = !parameter.HasDefaultValue ? - Expression.IfThen(TempSourceStringNotNullExpr, tryParseExpression) : - Expression.IfThenElse(TempSourceStringNotNullExpr, - tryParseExpression, - Expression.Assign(argument, Expression.Constant(parameter.DefaultValue))); + var ifNotNullTryParse = !parameter.HasDefaultValue + ? Expression.IfThen(TempSourceStringNotNullExpr, tryParseExpression) + : Expression.IfThenElse(TempSourceStringNotNullExpr, tryParseExpression, + Expression.Assign(argument, + Expression.Constant(parameter.DefaultValue))); + + var loopExit = Expression.Label(); + + // REVIEW: We can reuse this like we reuse temp source string + var stringArrayExpr = parameter.ParameterType.IsArray ? Expression.Variable(typeof(string[]), "tempStringArray") : null; + var elementTypeNullabilityInfo = parameter.ParameterType.IsArray ? factoryContext.NullabilityContext.Create(parameter)?.ElementType : null; + + // Determine optionality of the element type of the array + var elementTypeOptional = !isNotNullable || (elementTypeNullabilityInfo?.ReadState != NullabilityState.NotNull); + + // The loop that populates the resulting array values + var arrayLoop = parameter.ParameterType.IsArray ? Expression.Block( + // param_local = new int[values.Length]; + Expression.Assign(argument, Expression.NewArrayBounds(parameter.ParameterType.GetElementType()!, Expression.ArrayLength(stringArrayExpr!))), + // index = 0 + Expression.Assign(index, Expression.Constant(0)), + // while (index < values.Length) + Expression.Loop( + Expression.Block( + Expression.IfThenElse( + Expression.LessThan(index, Expression.ArrayLength(stringArrayExpr!)), + // tempSourceString = values[index]; + Expression.Block( + Expression.Assign(TempSourceStringExpr, Expression.ArrayIndex(stringArrayExpr!, index)), + elementTypeOptional ? Expression.IfThen(TempSourceStringIsNotNullOrEmptyExpr, tryParseExpression) + : tryParseExpression + ), + // else break + Expression.Break(loopExit) + ), + // index++ + Expression.PostIncrementAssign(index) + ) + , loopExit) + ) : null; + + var fullParamCheckBlock = (parameter.ParameterType.IsArray, isOptional) switch + { + // (isArray: true, optional: true) + (true, true) => + + Expression.Block( + new[] { index, stringArrayExpr! }, + // values = httpContext.Request.Query["id"]; + Expression.Assign(stringArrayExpr!, valueExpression), + Expression.IfThen( + Expression.NotEqual(stringArrayExpr!, Expression.Constant(null)), + arrayLoop! + ) + ), + + // (isArray: true, optional: false) + (true, false) => + + Expression.Block( + new[] { index, stringArrayExpr! }, + // values = httpContext.Request.Query["id"]; + Expression.Assign(stringArrayExpr!, valueExpression), + Expression.IfThenElse( + Expression.NotEqual(stringArrayExpr!, Expression.Constant(null)), + arrayLoop!, + failBlock + ) + ), - var fullParamCheckBlock = !isOptional - ? Expression.Block( + // (isArray: false, optional: false) + (false, false) => + + Expression.Block( // tempSourceString = httpContext.RequestValue["id"]; Expression.Assign(TempSourceStringExpr, valueExpression), // if (tempSourceString == null) { ... } only produced when parameter is required checkRequiredParaseableParameterBlock, // if (tempSourceString != null) { ... } - ifNotNullTryParse) - : Expression.Block( + ifNotNullTryParse), + + // (isArray: false, optional: true) + (false, true) => + + Expression.Block( // tempSourceString = httpContext.RequestValue["id"]; Expression.Assign(TempSourceStringExpr, valueExpression), // if (tempSourceString != null) { ... } - ifNotNullTryParse); + ifNotNullTryParse) + }; factoryContext.ExtraLocals.Add(argument); factoryContext.ParamCheckExpressions.Add(fullParamCheckBlock); @@ -1065,7 +1176,12 @@ private static Expression BindParameterFromExpression( } private static Expression BindParameterFromProperty(ParameterInfo parameter, MemberExpression property, string key, FactoryContext factoryContext, string source) => - BindParameterFromValue(parameter, GetValueFromProperty(property, key), factoryContext, source); + BindParameterFromValue(parameter, GetValueFromProperty(property, key, GetExpressionType(parameter.ParameterType)), factoryContext, source); + + private static Type? GetExpressionType(Type type) => + type.IsArray ? typeof(string[]) : + type == typeof(StringValues) ? typeof(StringValues) : + null; private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo parameter, string key, FactoryContext factoryContext) { @@ -1077,7 +1193,6 @@ private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo private static Expression BindParameterFromBindAsync(ParameterInfo parameter, FactoryContext factoryContext) { // We reference the boundValues array by parameter index here - var nullability = factoryContext.NullabilityContext.Create(parameter); var isOptional = IsOptionalParameter(parameter, factoryContext); // Get the BindAsync method for the type. diff --git a/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs b/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs index dad9ad33752f..0bbd48d1f9f3 100644 --- a/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs +++ b/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs @@ -144,7 +144,7 @@ public static IEnumerable TryParseStringParameterInfoData [MemberData(nameof(TryParseStringParameterInfoData))] public void HasTryParseStringMethod_ReturnsTrueWhenMethodExists(ParameterInfo parameterInfo) { - Assert.True(new ParameterBindingMethodCache().HasTryParseMethod(parameterInfo)); + Assert.True(new ParameterBindingMethodCache().HasTryParseMethod(parameterInfo.ParameterType)); } [Fact] diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index ef22b01c11a0..87589f9ce955 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -451,6 +451,57 @@ void TestAction([FromRoute] int foo) Assert.Equal(400, httpContext.Response.StatusCode); } + public static object?[][] TryParsableArrayParameters + { + get + { + static void Store(HttpContext httpContext, T tryParsable) + { + httpContext.Items["tryParsable"] = tryParsable; + } + + var now = DateTime.Now; + + return new[] + { + // string is not technically "TryParsable", but it's the special case. + new object[] { (Action)Store, new[] { "plain string" }, new[] { "plain string" } }, + new object[] { (Action)Store, new[] { "1", "2", "3" }, new StringValues(new[] { "1", "2", "3" }) }, + new object[] { (Action)Store, new[] { "-1", "2", "3" }, new[] { -1,2,3 } }, + new object[] { (Action)Store, new[] { "1","42","32"}, new[] { 1U, 42U, 32U } }, + new object[] { (Action)Store, new[] { "true", "false" }, new[] { true, false } }, + new object[] { (Action)Store, new[] { "-42" }, new[] { (short)-42 } }, + new object[] { (Action)Store, new[] { "42" }, new[] { (ushort)42 } }, + new object[] { (Action)Store, new[] { "-42" }, new[] { -42L } }, + new object[] { (Action)Store, new[] { "42" }, new[] { 42UL } }, + new object[] { (Action)Store, new[] { "-42" },new[] { new IntPtr(-42) } }, + new object[] { (Action)Store, new[] { "A" }, new[] { 'A' } }, + new object[] { (Action)Store, new[] { "0.5" },new[] { 0.5 } }, + new object[] { (Action)Store, new[] { "0.5" },new[] { 0.5f } }, + new object[] { (Action)Store, new[] { "0.5" }, new[] { (Half)0.5f } }, + new object[] { (Action)Store, new[] { "0.5" },new[] { 0.5m } }, + new object[] { (Action)Store, new[] { now.ToString("o") },new[] { now.ToUniversalTime() } }, + new object[] { (Action)Store, new[] { "1970-01-01T00:00:00.0000000+00:00" },new[] { DateTimeOffset.UnixEpoch } }, + new object[] { (Action)Store, new[] { "00:00:42" },new[] { TimeSpan.FromSeconds(42) } }, + new object[] { (Action)Store, new[] { "00000000-0000-0000-0000-000000000000" },new[] { Guid.Empty } }, + new object[] { (Action)Store, new[] { "6.0.0.42" }, new[] { new Version("6.0.0.42") } }, + new object[] { (Action)Store, new[] { "-42" },new[]{ new BigInteger(-42) } }, + new object[] { (Action)Store, new[] { "127.0.0.1" }, new[] { IPAddress.Loopback } }, + new object[] { (Action)Store, new[] { "127.0.0.1:80" },new[] { new IPEndPoint(IPAddress.Loopback, 80) } }, + new object[] { (Action)Store, new[] { "Unix" },new[] { AddressFamily.Unix } }, + new object[] { (Action)Store, new[] { "Nop" }, new[] { ILOpCode.Nop } }, + new object[] { (Action)Store, new[] { "PublicKey,Retargetable" },new[] { AssemblyFlags.PublicKey | AssemblyFlags.Retargetable } }, + new object[] { (Action)Store, new[] { "42" }, new int?[] { 42 } }, + new object[] { (Action)Store, new[] { "ValueB" },new[] { MyEnum.ValueB } }, + new object[] { (Action)Store, new[] { "https://example.org" },new[] { new MyTryParseRecord(new Uri("https://example.org")) } }, + new object?[] { (Action)Store, new string[] {}, Array.Empty() }, + new object?[] { (Action)Store, new string?[] { "1", "2", null, "4" }, new int?[] { 1,2, null, 4 } }, + new object?[] { (Action)Store, new string[] { "1", "2", "", "4" }, new int?[] { 1,2, null, 4 } }, + new object[] { (Action)Store, new[] { "" }, new MyTryParseRecord?[] { null } }, + }; + } + } + public static object?[][] TryParsableParameters { get @@ -702,12 +753,12 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR [Theory] [MemberData(nameof(TryParsableParameters))] - public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromQueryString(Delegate action, string? routeValue, object? expectedParameterValue) + public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromQueryString(Delegate action, string? queryValue, object? expectedParameterValue) { var httpContext = CreateHttpContext(); httpContext.Request.Query = new QueryCollection(new Dictionary { - ["tryParsable"] = routeValue + ["tryParsable"] = queryValue }); var factoryResult = RequestDelegateFactory.Create(action); @@ -715,9 +766,104 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromQ await requestDelegate(httpContext); + Assert.NotEmpty(httpContext.Items); + Assert.Equal(expectedParameterValue, httpContext.Items["tryParsable"]); + } + + [Theory] + [MemberData(nameof(TryParsableArrayParameters))] + public async Task RequestDelegateHandlesArraysFromQueryString(Delegate action, string[]? queryValues, object? expectedParameterValue) + { + var httpContext = CreateHttpContext(); + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["tryParsable"] = queryValues + }); + + var factoryResult = RequestDelegateFactory.Create(action, new() { DisableInferBodyFromParameters = true }); + + var requestDelegate = factoryResult.RequestDelegate; + + await requestDelegate(httpContext); + + Assert.NotEmpty(httpContext.Items); Assert.Equal(expectedParameterValue, httpContext.Items["tryParsable"]); } + [Theory] + [MemberData(nameof(TryParsableArrayParameters))] + public async Task RequestDelegateHandlesDoesNotHandleArraysFromQueryStringWhenBodyIsInferred(Delegate action, string[]? queryValues, object? expectedParameterValue) + { + var httpContext = CreateHttpContext(); + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["tryParsable"] = queryValues + }); + + var factoryResult = RequestDelegateFactory.Create(action); + + var requestDelegate = factoryResult.RequestDelegate; + + await requestDelegate(httpContext); + + // Assert.NotEmpty(httpContext.Items); + Assert.Null(httpContext.Items["tryParsable"]); + + // Ignore this parameter but we want to reuse the dataset + GC.KeepAlive(expectedParameterValue); + } + + [Fact] + public async Task RequestDelegateHandlesOptionalArraysFromNullQueryString() + { + var httpContext = CreateHttpContext(); + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["tryParsable"] = (string?)null + }); + + static void StoreNullableIntArray(HttpContext httpContext, int?[]? tryParsable) + { + httpContext.Items["tryParsable"] = tryParsable; + } + + var factoryResult = RequestDelegateFactory.Create(StoreNullableIntArray, new() { DisableInferBodyFromParameters = true }); + + var requestDelegate = factoryResult.RequestDelegate; + + await requestDelegate(httpContext); + + Assert.NotEmpty(httpContext.Items); + Assert.Null(httpContext.Items["tryParsable"]); + } + + [Fact] + public async Task RequestDelegateHandlesArraysFromExplicitQueryStringSource() + { + var httpContext = CreateHttpContext(); + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["a"] = new(new[] { "1", "2", "3" }) + }); + + httpContext.Request.Headers["Custom"] = new(new[] { "4", "5", "6" }); + + var factoryResult = RequestDelegateFactory.Create((HttpContext context, + [FromHeader(Name = "Custom")] int[] headerValues, + [FromQuery(Name = "a")] int[] queryValues) => + { + context.Items["headers"] = headerValues; + context.Items["query"] = queryValues; + }); + + var requestDelegate = factoryResult.RequestDelegate; + + await requestDelegate(httpContext); + + Assert.Equal(new[] { 1, 2, 3 }, (int[])httpContext.Items["query"]!); + Assert.Equal(new[] { 4, 5, 6 }, (int[])httpContext.Items["headers"]!); + } + [Fact] public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromRouteValueBeforeQueryString() { @@ -980,6 +1126,76 @@ void TestAction([FromRoute] int tryParsable, [FromRoute] int tryParsable2) Assert.Equal(400, badHttpRequestException.StatusCode); } + [Fact] + public async Task RequestDelegateThrowsForTryParsableFailuresIfThrowOnBadRequestWithArrays() + { + var invoked = false; + + void TestAction([FromQuery] int[] values) + { + invoked = true; + } + + var httpContext = CreateHttpContext(); + httpContext.Request.Query = new QueryCollection(new Dictionary() + { + ["values"] = new(new[] { "1", "NAN", "3" }) + }); + + var factoryResult = RequestDelegateFactory.Create(TestAction, new() { ThrowOnBadRequest = true, DisableInferBodyFromParameters = true }); + var requestDelegate = factoryResult.RequestDelegate; + + var badHttpRequestException = await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + + Assert.False(invoked); + + // The httpContext should be untouched. + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.Response.HasStarted); + + // We don't log bad requests when we throw. + Assert.Empty(TestSink.Writes); + + Assert.Equal(@"Failed to bind parameter ""int[] values"" from ""NAN"".", badHttpRequestException.Message); + Assert.Equal(400, badHttpRequestException.StatusCode); + } + + [Fact] + public async Task RequestDelegateThrowsForTryParsableFailuresIfThrowOnBadRequestWithNonOptionalArrays() + { + var invoked = false; + + void StoreNullableIntArray(HttpContext httpContext, int?[] values) + { + invoked = true; + } + + var httpContext = CreateHttpContext(); + httpContext.Request.Query = new QueryCollection(new Dictionary() + { + ["values"] = (string?)null + }); + + var factoryResult = RequestDelegateFactory.Create(StoreNullableIntArray, new() { ThrowOnBadRequest = true, DisableInferBodyFromParameters = true }); + var requestDelegate = factoryResult.RequestDelegate; + + var badHttpRequestException = await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + + Assert.False(invoked); + + // The httpContext should be untouched. + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.False(httpContext.Response.HasStarted); + + // We don't log bad requests when we throw. + Assert.Empty(TestSink.Writes); + + Assert.Equal(@"Failed to bind parameter ""Nullable[] values"" from """".", badHttpRequestException.Message); + Assert.Equal(400, badHttpRequestException.StatusCode); + } + [Fact] public async Task RequestDelegateLogsBindAsyncFailuresAndSets400Response() { diff --git a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs index 9b22fd7c6e6f..57037d5ef14e 100644 --- a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs +++ b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs @@ -16,6 +16,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Mvc.ApiExplorer; @@ -44,6 +45,18 @@ public EndpointMetadataApiDescriptionProvider( public void OnProvidersExecuting(ApiDescriptionProviderContext context) { + // Keep in sync with EndpointRouteBuilderExtensions.cs + static bool ShouldDisableInferredBody(string method) + { + // GET, DELETE, HEAD, CONNECT, TRACE, and OPTIONS normally do not contain bodies + return method.Equals(HttpMethods.Get, StringComparison.Ordinal) || + method.Equals(HttpMethods.Delete, StringComparison.Ordinal) || + method.Equals(HttpMethods.Head, StringComparison.Ordinal) || + method.Equals(HttpMethods.Options, StringComparison.Ordinal) || + method.Equals(HttpMethods.Trace, StringComparison.Ordinal) || + method.Equals(HttpMethods.Connect, StringComparison.Ordinal); + } + foreach (var endpoint in _endpointDataSource.Endpoints) { if (endpoint is RouteEndpoint routeEndpoint && @@ -51,12 +64,15 @@ public void OnProvidersExecuting(ApiDescriptionProviderContext context) routeEndpoint.Metadata.GetMetadata() is { } httpMethodMetadata && routeEndpoint.Metadata.GetMetadata() is null or { ExcludeFromDescription: false }) { + // We need to detect if any of the methods allow inferred body + var disableInferredBody = httpMethodMetadata.HttpMethods.Any(ShouldDisableInferredBody); + // REVIEW: Should we add an ApiDescription for endpoints without IHttpMethodMetadata? Swagger doesn't handle // a null HttpMethod even though it's nullable on ApiDescription, so we'd need to define "default" HTTP methods. // In practice, the Delegate will be called for any HTTP method if there is no IHttpMethodMetadata. foreach (var httpMethod in httpMethodMetadata.HttpMethods) { - context.Results.Add(CreateApiDescription(routeEndpoint, httpMethod, methodInfo)); + context.Results.Add(CreateApiDescription(routeEndpoint, httpMethod, methodInfo, disableInferredBody)); } } } @@ -66,7 +82,7 @@ public void OnProvidersExecuted(ApiDescriptionProviderContext context) { } - private ApiDescription CreateApiDescription(RouteEndpoint routeEndpoint, string httpMethod, MethodInfo methodInfo) + private ApiDescription CreateApiDescription(RouteEndpoint routeEndpoint, string httpMethod, MethodInfo methodInfo, bool disableInferredBody) { // Swashbuckle uses the "controller" name to group endpoints together. // For now, put all methods defined the same declaring type together. @@ -102,7 +118,7 @@ private ApiDescription CreateApiDescription(RouteEndpoint routeEndpoint, string foreach (var parameter in methodInfo.GetParameters()) { - var parameterDescription = CreateApiParameterDescription(parameter, routeEndpoint.RoutePattern); + var parameterDescription = CreateApiParameterDescription(parameter, routeEndpoint.RoutePattern, disableInferredBody); if (parameterDescription is null) { @@ -155,9 +171,9 @@ private ApiDescription CreateApiDescription(RouteEndpoint routeEndpoint, string return apiDescription; } - private ApiParameterDescription? CreateApiParameterDescription(ParameterInfo parameter, RoutePattern pattern) + private ApiParameterDescription? CreateApiParameterDescription(ParameterInfo parameter, RoutePattern pattern, bool disableInferredBody) { - var (source, name, allowEmpty, paramType) = GetBindingSourceAndName(parameter, pattern); + var (source, name, allowEmpty, paramType) = GetBindingSourceAndName(parameter, pattern, disableInferredBody); // Services are ignored because they are not request parameters. if (source == BindingSource.Services) @@ -230,7 +246,7 @@ private static ParameterDescriptor CreateParameterDescriptor(ParameterInfo param // TODO: Share more of this logic with RequestDelegateFactory.CreateArgument(...) using RequestDelegateFactoryUtilities // which is shared source. - private (BindingSource, string, bool, Type) GetBindingSourceAndName(ParameterInfo parameter, RoutePattern pattern) + private (BindingSource, string, bool, Type) GetBindingSourceAndName(ParameterInfo parameter, RoutePattern pattern, bool disableInferredBody) { var attributes = parameter.GetCustomAttributes(); @@ -265,7 +281,7 @@ private static ParameterDescriptor CreateParameterDescriptor(ParameterInfo param { return (BindingSource.Services, parameter.Name ?? string.Empty, false, parameter.ParameterType); } - else if (parameter.ParameterType == typeof(string) || ParameterBindingMethodCache.HasTryParseMethod(parameter)) + else if (parameter.ParameterType == typeof(string) || ParameterBindingMethodCache.HasTryParseMethod(parameter.ParameterType)) { // complex types will display as strings since they use custom parsing via TryParse on a string var displayType = !parameter.ParameterType.IsPrimitive && Nullable.GetUnderlyingType(parameter.ParameterType)?.IsPrimitive != true @@ -284,6 +300,13 @@ private static ParameterDescriptor CreateParameterDescriptor(ParameterInfo param { return (BindingSource.FormFile, parameter.Name ?? string.Empty, false, parameter.ParameterType); } + else if (disableInferredBody && ( + (parameter.ParameterType.IsArray && ParameterBindingMethodCache.HasTryParseMethod(parameter.ParameterType.GetElementType()!)) || + parameter.ParameterType == typeof(string[]) || + parameter.ParameterType == typeof(StringValues))) + { + return (BindingSource.Query, parameter.Name ?? string.Empty, false, parameter.ParameterType); + } else { return (BindingSource.Body, parameter.Name ?? string.Empty, false, parameter.ParameterType); diff --git a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs index aef0b7ba2e4b..56c00a475b82 100644 --- a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs +++ b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs @@ -16,6 +16,7 @@ using Microsoft.Extensions.FileProviders; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Mvc.ApiExplorer; @@ -329,16 +330,39 @@ static void AssertPathParameter(ApiDescription apiDescription) [Fact] public void AddsFromQueryParameterAsQuery() { - static void AssertQueryParameter(ApiDescription apiDescription) + static void AssertQueryParameter(ApiDescription apiDescription) { var param = Assert.Single(apiDescription.ParameterDescriptions); - Assert.Equal(typeof(int), param.Type); - Assert.Equal(typeof(int), param.ModelMetadata.ModelType); + Assert.Equal(typeof(T), param.Type); + Assert.Equal(typeof(T), param.ModelMetadata.ModelType); Assert.Equal(BindingSource.Query, param.Source); } - AssertQueryParameter(GetApiDescription((int foo) => { }, "/")); - AssertQueryParameter(GetApiDescription(([FromQuery] int foo) => { })); + AssertQueryParameter(GetApiDescription((int foo) => { }, "/")); + AssertQueryParameter(GetApiDescription(([FromQuery] int foo) => { })); + AssertQueryParameter(GetApiDescription(([FromQuery] TryParseStringRecordStruct foo) => { })); + AssertQueryParameter(GetApiDescription((int[] foo) => { }, "/")); + AssertQueryParameter(GetApiDescription((string[] foo) => { }, "/")); + AssertQueryParameter(GetApiDescription((StringValues foo) => { }, "/")); + AssertQueryParameter(GetApiDescription((TryParseStringRecordStruct[] foo) => { }, "/")); + } + + [Theory] + [InlineData("Put")] + [InlineData("Post")] + public void BodyIsInferredForArraysInsteadOfQuerySomeHttpMethods(string httpMethod) + { + static void AssertBody(ApiDescription apiDescription) + { + var param = Assert.Single(apiDescription.ParameterDescriptions); + Assert.Equal(typeof(T), param.Type); + Assert.Equal(typeof(T), param.ModelMetadata.ModelType); + Assert.Equal(BindingSource.Body, param.Source); + } + + AssertBody(GetApiDescription((int[] foo) => { }, "/", httpMethods: new[] { httpMethod })); + AssertBody(GetApiDescription((string[] foo) => { }, "/", httpMethods: new[] { httpMethod })); + AssertBody(GetApiDescription((TryParseStringRecordStruct[] foo) => { }, "/", httpMethods: new[] { httpMethod })); } [Fact] @@ -1163,8 +1187,8 @@ private static IList GetApiDescriptions( private static TestEndpointRouteBuilder CreateBuilder() => new TestEndpointRouteBuilder(new ApplicationBuilder(new TestServiceProvider())); - private static ApiDescription GetApiDescription(Delegate action, string pattern = null, string displayName = null) => - Assert.Single(GetApiDescriptions(action, pattern, displayName: displayName)); + private static ApiDescription GetApiDescription(Delegate action, string pattern = null, string displayName = null, IEnumerable httpMethods = null) => + Assert.Single(GetApiDescriptions(action, pattern, displayName: displayName, httpMethods: httpMethods)); private static void TestAction() { diff --git a/src/Shared/ParameterBindingMethodCache.cs b/src/Shared/ParameterBindingMethodCache.cs index 228ec6fad549..810dd7923618 100644 --- a/src/Shared/ParameterBindingMethodCache.cs +++ b/src/Shared/ParameterBindingMethodCache.cs @@ -43,9 +43,9 @@ public ParameterBindingMethodCache(bool preferNonGenericEnumParseOverload) _enumTryParseMethod = GetEnumTryParseMethod(preferNonGenericEnumParseOverload); } - public bool HasTryParseMethod(ParameterInfo parameter) + public bool HasTryParseMethod(Type type) { - var nonNullableParameterType = Nullable.GetUnderlyingType(parameter.ParameterType) ?? parameter.ParameterType; + var nonNullableParameterType = Nullable.GetUnderlyingType(type) ?? type; return FindTryParseMethod(nonNullableParameterType) is not null; }