diff --git a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs index 2aa3c26ee16b13..687fa61ab06d28 100644 --- a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs +++ b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs @@ -2116,6 +2116,45 @@ await server.AcceptConnectionAsync(async connection => }); } + [Fact] + public async Task SendAsync_Expect100Continue_RequestBodyFails_ThrowsContentException() + { + if (IsWinHttpHandler && UseVersion >= HttpVersion20.Value) + { + return; + } + if (!TestAsync && UseVersion >= HttpVersion20.Value) + { + return; + } + + var clientFinished = new TaskCompletionSource(); + + await LoopbackServerFactory.CreateClientAndServerAsync(async uri => + { + using (HttpClient client = CreateHttpClient()) + { + HttpRequestMessage initialMessage = new HttpRequestMessage(HttpMethod.Post, uri) { Version = UseVersion }; + initialMessage.Content = new ThrowingContent(() => new ThrowingContentException()); + initialMessage.Headers.ExpectContinue = true; + Exception exception = await Assert.ThrowsAsync(() => client.SendAsync(TestAsync, initialMessage)); + + clientFinished.SetResult(true); + } + }, async server => + { + await server.AcceptConnectionAsync(async connection => + { + try + { + await connection.ReadRequestDataAsync(readBody: true); + } + catch { } // Eat errors from client disconnect. + await clientFinished.Task.TimeoutAfter(TimeSpan.FromMinutes(2)); + }); + }); + } + [Fact] public async Task SendAsync_No100ContinueReceived_RequestBodySentEventually() { diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/ThrowingContent.cs b/src/libraries/Common/tests/System/Net/Http/ThrowingContent.cs similarity index 88% rename from src/libraries/System.Net.Http/tests/FunctionalTests/ThrowingContent.cs rename to src/libraries/Common/tests/System/Net/Http/ThrowingContent.cs index 276570552bd19f..8b757e8830ea94 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/ThrowingContent.cs +++ b/src/libraries/Common/tests/System/Net/Http/ThrowingContent.cs @@ -10,7 +10,7 @@ namespace System.Net.Http.Functional.Tests { /// HttpContent that mocks exceptions on serialization. - public class ThrowingContent : HttpContent + public partial class ThrowingContent : HttpContent { private readonly Func _exnFactory; private readonly int _length; @@ -32,4 +32,7 @@ protected override bool TryComputeLength(out long length) return true; } } + + public class ThrowingContentException : Exception + { } } diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/System.Net.Http.WinHttpHandler.Functional.Tests.csproj b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/System.Net.Http.WinHttpHandler.Functional.Tests.csproj index b1dbc739244a7c..1d0dceca18b31c 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/System.Net.Http.WinHttpHandler.Functional.Tests.csproj +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/System.Net.Http.WinHttpHandler.Functional.Tests.csproj @@ -129,6 +129,8 @@ Link="Common\System\Net\Http\SchSendAuxRecordHttpTest.cs" /> + SendAsync(HttpRequestMess if (requestBodyTask.IsCompleted || duplex == false || await Task.WhenAny(requestBodyTask, responseHeadersTask).ConfigureAwait(false) == requestBodyTask || - requestBodyTask.IsCompleted) + requestBodyTask.IsCompleted || + http2Stream.SendRequestFinished) { // The sending of the request body completed before receiving all of the request headers (or we're // ok waiting for the request body even if it hasn't completed, e.g. because we're not doing duplex). diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index f9cee658c61e7c..5ac829bbe3dc0a 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -145,6 +145,8 @@ public void Initialize(int streamId, int initialWindowSize) public int StreamId { get; private set; } + public bool SendRequestFinished => _requestCompletionState != StreamCompletionState.InProgress; + public HttpResponseMessage GetAndClearResponse() { // Once SendAsync completes, the Http2Stream should no longer hold onto the response message. diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index 4a9485a472428e..780212df78062e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -692,6 +692,21 @@ public async Task SendAsyncCore(HttpRequestMessage request, // hook up a continuation that will log it. if (sendRequestContentTask != null && !sendRequestContentTask.IsCompletedSuccessfully) { + // In case the connection is disposed, it's most probable that + // expect100Continue timer expired and request content sending failed. + // We're awaiting the task to propagate the exception in this case. + if (Volatile.Read(ref _disposed) == 1) + { + if (async) + { + await sendRequestContentTask.ConfigureAwait(false); + } + else + { + // No way around it here if we want to get the exception from the task. + sendRequestContentTask.GetAwaiter().GetResult(); + } + } LogExceptions(sendRequestContentTask); } @@ -793,7 +808,8 @@ private async ValueTask SendRequestContentAsync(HttpRequestMessage request, Http } private async Task SendRequestContentWithExpect100ContinueAsync( - HttpRequestMessage request, Task allowExpect100ToContinueTask, HttpContentWriteStream stream, Timer expect100Timer, bool async, CancellationToken cancellationToken) + HttpRequestMessage request, Task allowExpect100ToContinueTask, + HttpContentWriteStream stream, Timer expect100Timer, bool async, CancellationToken cancellationToken) { // Wait until we receive a trigger notification that it's ok to continue sending content. // This will come either when the timer fires or when we receive a response status line from the server. @@ -806,7 +822,17 @@ private async Task SendRequestContentWithExpect100ContinueAsync( if (sendRequestContent) { if (NetEventSource.Log.IsEnabled()) Trace($"Sending request content for Expect: 100-continue."); - await SendRequestContentAsync(request, stream, async, cancellationToken).ConfigureAwait(false); + try + { + await SendRequestContentAsync(request, stream, async, cancellationToken).ConfigureAwait(false); + } + catch + { + // Tear down the connection if called from the timer thread because caller's thread will wait for server status line indefinitely + // or till HttpClient.Timeout tear the connection itself. + Dispose(); + throw; + } } else { diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj b/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj index dc5bc00ef6add7..2e00a892170297 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj @@ -179,7 +179,6 @@ - @@ -227,7 +226,9 @@ Link="Common\System\Net\Http\SyncBlockingContent.cs" /> - + + diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/ThrowingContent.netcore.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/ThrowingContent.netcore.cs new file mode 100644 index 00000000000000..8f4fd472477034 --- /dev/null +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/ThrowingContent.netcore.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Http.Functional.Tests +{ + public partial class ThrowingContent : HttpContent + { + protected override void SerializeToStream(Stream stream, TransportContext context, CancellationToken cancellationToken) + { + throw _exnFactory(); + } + } +}