diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 553a254c2770d4..9436a7f2731d59 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -510,15 +510,16 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, if (writeTask.IsCompleted) { writeTask.GetAwaiter().GetResult(); - ValueTask flushTask = new ValueTask(_stream.FlushAsync()); + Task flushTask = _stream.FlushAsync(); if (flushTask.IsCompleted) { - return flushTask; + flushTask.GetAwaiter().GetResult(); + return ValueTask.CompletedTask; } else { releaseSendBufferAndSemaphore = false; - return WaitForWriteTaskAsync(flushTask, shouldFlush: false); + return WaitForWriteTaskAsync(new ValueTask(flushTask), shouldFlush: false); } } diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs index 64ec09f88807f2..e5a87210a43166 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs @@ -86,6 +86,12 @@ public Span NextAvailableBytes /// public bool IgnoreCancellationToken { get; set; } + /// + /// If set, causes FlushAsync to return a synchronously-faulted Task with this exception. + /// Used to exercise sync-completion-faulted code paths in the WebSocket send flow. + /// + public Exception? FlushException { get; set; } + public override bool CanRead => true; public override bool CanSeek => false; @@ -226,6 +232,15 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella public override void Flush() { } + public override Task FlushAsync(CancellationToken cancellationToken) + { + if (FlushException is not null) + { + return Task.FromException(FlushException); + } + return base.FlushAsync(cancellationToken); + } + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index 41c7fe341d266d..cf01437fc3d97d 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -182,6 +182,23 @@ public async Task ThrowWhenContinuationWithDifferentCompressionFlags() client.SendAsync(Memory.Empty, WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, default)); } + [Fact] + public async Task SendAsync_FlushAsyncSyncFaulted_WrapsExceptionInWebSocketException() + { + var underlying = new IOException("flush failed"); + using var stream = new WebSocketTestStream { FlushException = underlying }; + using WebSocket ws = WebSocket.CreateFromStream( + stream, isServer: false, subProtocol: null, keepAliveInterval: Timeout.InfiniteTimeSpan); + + var buffer = new ArraySegment(new byte[] { 1, 2, 3 }); + + WebSocketException ex = await Assert.ThrowsAsync( + () => ws.SendAsync(buffer, WebSocketMessageType.Binary, endOfMessage: true, CancellationToken.None)); + + Assert.Equal(WebSocketError.ConnectionClosedPrematurely, ex.WebSocketErrorCode); + Assert.Same(underlying, ex.InnerException); + } + [Fact] public async Task ReceiveAsync_ServerUnmaskedFrame_ThrowsWebSocketException() {