diff --git a/src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs b/src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs index 7beaa7996cca75..ff9527f9bfed57 100644 --- a/src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs +++ b/src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs @@ -8,6 +8,7 @@ using System.Net.Cache; using System.Net.Http; using System.Net.Security; +using System.Net.Sockets; using System.Runtime.Serialization; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; @@ -97,11 +98,13 @@ private enum Booleans : uint private class HttpClientParameters { + public readonly bool Async; public readonly DecompressionMethods AutomaticDecompression; public readonly bool AllowAutoRedirect; public readonly int MaximumAutomaticRedirections; public readonly int MaximumResponseHeadersLength; public readonly bool PreAuthenticate; + public readonly int ReadWriteTimeout; public readonly TimeSpan Timeout; public readonly SecurityProtocolType SslProtocols; public readonly bool CheckCertificateRevocationList; @@ -111,13 +114,15 @@ private class HttpClientParameters public readonly X509CertificateCollection? ClientCertificates; public readonly CookieContainer? CookieContainer; - public HttpClientParameters(HttpWebRequest webRequest) + public HttpClientParameters(HttpWebRequest webRequest, bool async) { + Async = async; AutomaticDecompression = webRequest.AutomaticDecompression; AllowAutoRedirect = webRequest.AllowAutoRedirect; MaximumAutomaticRedirections = webRequest.MaximumAutomaticRedirections; MaximumResponseHeadersLength = webRequest.MaximumResponseHeadersLength; PreAuthenticate = webRequest.PreAuthenticate; + ReadWriteTimeout = webRequest.ReadWriteTimeout; Timeout = webRequest.Timeout == Threading.Timeout.Infinite ? Threading.Timeout.InfiniteTimeSpan : TimeSpan.FromMilliseconds(webRequest.Timeout); @@ -132,11 +137,13 @@ public HttpClientParameters(HttpWebRequest webRequest) public bool Matches(HttpClientParameters requestParameters) { - return AutomaticDecompression == requestParameters.AutomaticDecompression + return Async == requestParameters.Async + && AutomaticDecompression == requestParameters.AutomaticDecompression && AllowAutoRedirect == requestParameters.AllowAutoRedirect && MaximumAutomaticRedirections == requestParameters.MaximumAutomaticRedirections && MaximumResponseHeadersLength == requestParameters.MaximumResponseHeadersLength && PreAuthenticate == requestParameters.PreAuthenticate + && ReadWriteTimeout == requestParameters.ReadWriteTimeout && Timeout == requestParameters.Timeout && SslProtocols == requestParameters.SslProtocols && CheckCertificateRevocationList == requestParameters.CheckCertificateRevocationList @@ -1122,7 +1129,7 @@ private async Task SendRequest(bool async) HttpClient? client = null; try { - client = GetCachedOrCreateHttpClient(out disposeRequired); + client = GetCachedOrCreateHttpClient(async, out disposeRequired); if (_requestStream != null) { ArraySegment bytes = _requestStream.GetBuffer(); @@ -1443,9 +1450,9 @@ private bool TryGetHostUri(string hostName, [NotNullWhen(true)] out Uri? hostUri return Uri.TryCreate(s, UriKind.Absolute, out hostUri); } - private HttpClient GetCachedOrCreateHttpClient(out bool disposeRequired) + private HttpClient GetCachedOrCreateHttpClient(bool async, out bool disposeRequired) { - var parameters = new HttpClientParameters(this); + var parameters = new HttpClientParameters(this, async); if (parameters.AreParametersAcceptableForCaching()) { disposeRequired = false; @@ -1477,7 +1484,7 @@ private static HttpClient CreateHttpClient(HttpClientParameters parameters, Http HttpClient? client = null; try { - var handler = new HttpClientHandler(); + var handler = new SocketsHttpHandler(); client = new HttpClient(handler); handler.AutomaticDecompression = parameters.AutomaticDecompression; handler.Credentials = parameters.Credentials; @@ -1528,20 +1535,55 @@ private static HttpClient CreateHttpClient(HttpClientParameters parameters, Http if (parameters.ClientCertificates != null) { - handler.ClientCertificates.AddRange(parameters.ClientCertificates); + handler.SslOptions.ClientCertificates = new X509CertificateCollection(parameters.ClientCertificates); } // Set relevant properties from ServicePointManager - handler.SslProtocols = (SslProtocols)parameters.SslProtocols; - handler.CheckCertificateRevocationList = parameters.CheckCertificateRevocationList; + handler.SslOptions.EnabledSslProtocols = (SslProtocols)parameters.SslProtocols; + handler.SslOptions.CertificateRevocationCheckMode = parameters.CheckCertificateRevocationList ? X509RevocationMode.Online : X509RevocationMode.NoCheck; RemoteCertificateValidationCallback? rcvc = parameters.ServerCertificateValidationCallback; if (rcvc != null) { - RemoteCertificateValidationCallback localRcvc = rcvc; - HttpWebRequest localRequest = request!; - handler.ServerCertificateCustomValidationCallback = (message, cert, chain, errors) => localRcvc(localRequest, cert, chain, errors); + handler.SslOptions.RemoteCertificateValidationCallback = (message, cert, chain, errors) => rcvc(request!, cert, chain, errors); } + // Set up a ConnectCallback so that we can control Socket-specific settings, like ReadWriteTimeout => socket.Send/ReceiveTimeout. + handler.ConnectCallback = async (context, cancellationToken) => + { + var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + + try + { + socket.NoDelay = true; + if (parameters.ReadWriteTimeout > 0) // default is 5 minutes, so this is generally going to be true + { + socket.SendTimeout = socket.ReceiveTimeout = parameters.ReadWriteTimeout; + } + + if (parameters.Async) + { + await socket.ConnectAsync(context.DnsEndPoint, cancellationToken).ConfigureAwait(false); + } + else + { + using (cancellationToken.UnsafeRegister(s => ((Socket)s!).Dispose(), socket)) + { + socket.Connect(context.DnsEndPoint); + } + + // Throw in case cancellation caused the socket to be disposed after the Connect completed + cancellationToken.ThrowIfCancellationRequested(); + } + } + catch + { + socket.Dispose(); + throw; + } + + return new NetworkStream(socket, ownsSocket: true); + }; + return client; } catch diff --git a/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs b/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs index f124d686cf01f7..9e0bae66d5fac3 100644 --- a/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs +++ b/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs @@ -1063,6 +1063,45 @@ public void ReadWriteTimeout_NegativeOrZeroValue_Fail() Assert.Throws(() => { request.ReadWriteTimeout = -10; }); } + [Fact] + public async Task ReadWriteTimeout_CancelsResponse() + { + var tcs = new TaskCompletionSource(); + await LoopbackServer.CreateClientAndServerAsync(uri => Task.Run(async () => + { + try + { + HttpWebRequest request = WebRequest.CreateHttp(uri); + request.ReadWriteTimeout = 10; + IOException e = await Assert.ThrowsAsync(async () => // exception type is WebException on .NET Framework + { + using WebResponse response = await GetResponseAsync(request); + using (Stream myStream = response.GetResponseStream()) + { + while (myStream.ReadByte() != -1) ; + } + }); + Assert.True(e.InnerException is SocketException se && se.SocketErrorCode == SocketError.TimedOut); + } + finally + { + tcs.SetResult(); + } + }), async server => + { + try + { + await server.AcceptConnectionAsync(async connection => + { + await connection.ReadRequestHeaderAsync(); + await connection.WriteStringAsync("HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\nHello Wor"); + await tcs.Task; + }); + } + catch { } + }); + } + [Theory, MemberData(nameof(EchoServers))] public void CookieContainer_SetThenGetContainer_Success(Uri remoteServer) { @@ -1668,7 +1707,7 @@ public void GetResponseAsync_ParametersAreNotCachable_CreateNewClient(HttpWebReq Task secondAccept = listener.AcceptAsync(); - Task secondResponseTask = request1.GetResponseAsync(); + Task secondResponseTask = bool.Parse(async) ? request1.GetResponseAsync() : Task.Run(() => request1.GetResponse()); await ReplyToClient(responseContent, server, serverReader); if (bool.Parse(connectionReusedString)) {