From aeb9023b52d3756b0a593d485fc2732cf00f9258 Mon Sep 17 00:00:00 2001 From: Mitch Denny Date: Tue, 20 Jun 2023 16:06:51 +1000 Subject: [PATCH 1/2] Backport PR 48892. --- .../WebSockets/src/ServerWebSocket.cs | 80 ++++++++++++ .../WebSockets/src/WebSocketMiddleware.cs | 4 +- .../UnitTests/WebSocketMiddlewareTests.cs | 122 ++++++++++++++++++ 3 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 src/Middleware/WebSockets/src/ServerWebSocket.cs diff --git a/src/Middleware/WebSockets/src/ServerWebSocket.cs b/src/Middleware/WebSockets/src/ServerWebSocket.cs new file mode 100644 index 000000000000..70be31cb0459 --- /dev/null +++ b/src/Middleware/WebSockets/src/ServerWebSocket.cs @@ -0,0 +1,80 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.WebSockets; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.WebSockets; + +/// +/// Used in ASP.NET Core to wrap a WebSocket with its associated HttpContext so that when the WebSocket is aborted +/// the underlying HttpContext is aborted. All other methods are delegated to the underlying WebSocket. +/// +internal sealed class ServerWebSocket : WebSocket +{ + private readonly WebSocket _wrappedSocket; + private readonly HttpContext _context; + + internal ServerWebSocket(WebSocket wrappedSocket, HttpContext context) + { + ArgumentNullException.ThrowIfNull(wrappedSocket); + ArgumentNullException.ThrowIfNull(context); + + _wrappedSocket = wrappedSocket; + _context = context; + } + + public override WebSocketCloseStatus? CloseStatus => _wrappedSocket.CloseStatus; + + public override string? CloseStatusDescription => _wrappedSocket.CloseStatusDescription; + + public override WebSocketState State => _wrappedSocket.State; + + public override string? SubProtocol => _wrappedSocket.SubProtocol; + + public override void Abort() + { + _wrappedSocket.Abort(); + _context.Abort(); + } + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) + { + return _wrappedSocket.CloseAsync(closeStatus, statusDescription, cancellationToken); + } + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) + { + return _wrappedSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + } + + public override void Dispose() + { + _wrappedSocket.Dispose(); + } + + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + return _wrappedSocket.ReceiveAsync(buffer, cancellationToken); + } + + public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) + { + return _wrappedSocket.ReceiveAsync(buffer, cancellationToken); + } + + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + } + + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) + { + return _wrappedSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken); + } + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + } +} diff --git a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs index 8f2bcca80ef6..0d9f61947c20 100644 --- a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs +++ b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs @@ -207,13 +207,15 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext) opaqueTransport = await _upgradeFeature!.UpgradeAsync(); // Sets status code to 101 } - return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions() + var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions() { IsServer = true, KeepAliveInterval = keepAliveInterval, SubProtocol = subProtocol, DangerousDeflateOptions = deflateOptions }); + + return new ServerWebSocket(wrappedSocket, _context); } public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders) diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs index 90fa6f08c9c0..c3935cc7d114 100644 --- a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs +++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs @@ -5,6 +5,7 @@ using System.Net.Http; using System.Net.WebSockets; using System.Text; +using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Testing; using Microsoft.Net.Http.Headers; @@ -495,6 +496,127 @@ public async Task CloseFromCloseReceived_Success() } } + [Fact] + public async Task WebSocket_Abort_Interrupts_Pending_ReceiveAsync() + { + WebSocket serverSocket = null; + + var socketWasAccepted = new ManualResetEventSlim(); + var socketWasAborted = new ManualResetEventSlim(); + var firstReceiveOccured = new ManualResetEventSlim(); + + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + serverSocket = await context.WebSockets.AcceptWebSocketAsync(); + socketWasAccepted.Set(); + + var serverBuffer = new byte[1024]; + + var finishedWithConnectionAborted = false; + + try + { + while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent) + { + var response = await serverSocket.ReceiveAsync(serverBuffer, default); + firstReceiveOccured.Set(); + } + } + catch (ConnectionAbortedException) + { + socketWasAborted.Set(); + finishedWithConnectionAborted = true; + } + finally + { + Assert.True(finishedWithConnectionAborted); + } + })) + { + var clientBuffer = new byte[1024]; + + using (var client = new ClientWebSocket()) + { + await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None); + + var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000); + Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time."); + + await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default); + + var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000); + Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time."); + + serverSocket.Abort(); + + var socketWasAbortedDidNotTimeout = socketWasAborted.Wait(1000); // Give it a second to process the abort. + Assert.True(socketWasAbortedDidNotTimeout, "Abort did not occur within the allotted time."); + } + } + } + + [Fact] + public async Task WebSocket_AllowsCancelling_Pending_ReceiveAsync_When_CancellationTokenProvided() + { + WebSocket serverSocket = null; + CancellationTokenSource cts = new CancellationTokenSource(); + + var socketWasAccepted = new ManualResetEventSlim(); + var operationWasCancelled = new ManualResetEventSlim(); + var firstReceiveOccured = new ManualResetEventSlim(); + + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + serverSocket = await context.WebSockets.AcceptWebSocketAsync(); + socketWasAccepted.Set(); + + var serverBuffer = new byte[1024]; + + var finishedWithOperationCancelled = false; + + try + { + while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent) + { + var response = await serverSocket.ReceiveAsync(serverBuffer, cts.Token); + firstReceiveOccured.Set(); + } + } + catch (OperationCanceledException) + { + operationWasCancelled.Set(); + finishedWithOperationCancelled = true; + } + finally + { + Assert.True(finishedWithOperationCancelled); + } + })) + { + var clientBuffer = new byte[1024]; + + using (var client = new ClientWebSocket()) + { + await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None); + + var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000); + Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time."); + + await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default); + + var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000); + Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time."); + + cts.Cancel(); + + var operationWasCancelledDidNotTimeout = operationWasCancelled.Wait(1000); // Give it a second to process the abort. + Assert.True(operationWasCancelledDidNotTimeout, "Cancel did not occur within the allotted time."); + } + } + } + [Theory] [InlineData(HttpStatusCode.OK, null)] [InlineData(HttpStatusCode.Forbidden, "")] From ca5540b43e52a41fd5e90e86a67b34ef08a0488d Mon Sep 17 00:00:00 2001 From: Mitch Denny Date: Fri, 23 Jun 2023 08:57:37 +1000 Subject: [PATCH 2/2] Fix bug found in unit test on main. --- .../UnitTests/WebSocketMiddlewareTests.cs | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs index c3935cc7d114..a89b161212c6 100644 --- a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs +++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs @@ -501,9 +501,13 @@ public async Task WebSocket_Abort_Interrupts_Pending_ReceiveAsync() { WebSocket serverSocket = null; + // Events that we want to sequence execution across client and server. var socketWasAccepted = new ManualResetEventSlim(); var socketWasAborted = new ManualResetEventSlim(); var firstReceiveOccured = new ManualResetEventSlim(); + var secondReceiveInitiated = new ManualResetEventSlim(); + + Exception receiveException = null; await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { @@ -513,24 +517,36 @@ public async Task WebSocket_Abort_Interrupts_Pending_ReceiveAsync() var serverBuffer = new byte[1024]; - var finishedWithConnectionAborted = false; - try { while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent) { - var response = await serverSocket.ReceiveAsync(serverBuffer, default); - firstReceiveOccured.Set(); + if (firstReceiveOccured.IsSet) + { + var pendingResponse = serverSocket.ReceiveAsync(serverBuffer, default); + secondReceiveInitiated.Set(); + var response = await pendingResponse; + } + else + { + var response = await serverSocket.ReceiveAsync(serverBuffer, default); + firstReceiveOccured.Set(); + } } } - catch (ConnectionAbortedException) + catch (ConnectionAbortedException ex) { socketWasAborted.Set(); - finishedWithConnectionAborted = true; + receiveException = ex; + } + catch (Exception ex) + { + // Capture this exception so a test failure can give us more information. + receiveException = ex; } finally { - Assert.True(finishedWithConnectionAborted); + Assert.IsType(receiveException); } })) { @@ -548,6 +564,9 @@ public async Task WebSocket_Abort_Interrupts_Pending_ReceiveAsync() var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000); Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time."); + var secondReceiveInitiatedDidNotTimeout = secondReceiveInitiated.Wait(10000); + Assert.True(secondReceiveInitiatedDidNotTimeout, "Second receive was not initiated within the allotted time."); + serverSocket.Abort(); var socketWasAbortedDidNotTimeout = socketWasAborted.Wait(1000); // Give it a second to process the abort.