diff --git a/src/Middleware/WebSockets/src/ServerWebSocket.cs b/src/Middleware/WebSockets/src/ServerWebSocket.cs new file mode 100644 index 000000000000..1fe6351237d2 --- /dev/null +++ b/src/Middleware/WebSockets/src/ServerWebSocket.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.WebSockets; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.WebSockets +{ + /// + /// Used in ASP.NET Core to wrap a WebSocket with its associated HttpContext so that when the WebSocket is aborted + /// the underlying HttpContext is aborted. All other methods are delegated to the underlying WebSocket. + /// + internal sealed class ServerWebSocket : WebSocket + { + private readonly WebSocket _wrappedSocket; + private readonly HttpContext _context; + + internal ServerWebSocket(WebSocket wrappedSocket, HttpContext context) + { + ArgumentNullException.ThrowIfNull(wrappedSocket); + ArgumentNullException.ThrowIfNull(context); + + _wrappedSocket = wrappedSocket; + _context = context; + } + + public override WebSocketCloseStatus? CloseStatus => _wrappedSocket.CloseStatus; + + public override string? CloseStatusDescription => _wrappedSocket.CloseStatusDescription; + + public override WebSocketState State => _wrappedSocket.State; + + public override string? SubProtocol => _wrappedSocket.SubProtocol; + + public override void Abort() + { + _wrappedSocket.Abort(); + _context.Abort(); + } + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) + { + return _wrappedSocket.CloseAsync(closeStatus, statusDescription, cancellationToken); + } + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) + { + return _wrappedSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + } + + public override void Dispose() + { + _wrappedSocket.Dispose(); + } + + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + return _wrappedSocket.ReceiveAsync(buffer, cancellationToken); + } + + public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) + { + return _wrappedSocket.ReceiveAsync(buffer, cancellationToken); + } + + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + } + + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) + { + return _wrappedSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken); + } + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + } + } +} + diff --git a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs index 6904450e6578..ff6a589bff18 100644 --- a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs +++ b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs @@ -194,13 +194,15 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext) Stream opaqueTransport = await _upgradeFeature.UpgradeAsync(); // Sets status code to 101 - return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions() + var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions() { IsServer = true, KeepAliveInterval = keepAliveInterval, SubProtocol = subProtocol, DangerousDeflateOptions = deflateOptions }); + + return new ServerWebSocket(wrappedSocket, _context); } public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders) diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs index aa67be817719..97395bdd819f 100644 --- a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs +++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs @@ -6,6 +6,7 @@ using System.Net.Http; using System.Net.WebSockets; using System.Text; +using Microsoft.AspNetCore.Connections; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Testing; @@ -499,6 +500,146 @@ public async Task CloseFromCloseReceived_Success() } } + [Fact] + public async Task WebSocket_Abort_Interrupts_Pending_ReceiveAsync() + { + WebSocket serverSocket = null; + + // Events that we want to sequence execution across client and server. + var socketWasAccepted = new ManualResetEventSlim(); + var socketWasAborted = new ManualResetEventSlim(); + var firstReceiveOccured = new ManualResetEventSlim(); + var secondReceiveInitiated = new ManualResetEventSlim(); + + Exception receiveException = null; + + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + serverSocket = await context.WebSockets.AcceptWebSocketAsync(); + socketWasAccepted.Set(); + + var serverBuffer = new byte[1024]; + + try + { + while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent) + { + if (firstReceiveOccured.IsSet) + { + var pendingResponse = serverSocket.ReceiveAsync(serverBuffer, default); + secondReceiveInitiated.Set(); + var response = await pendingResponse; + } + else + { + var response = await serverSocket.ReceiveAsync(serverBuffer, default); + firstReceiveOccured.Set(); + } + } + } + catch (ConnectionAbortedException ex) + { + socketWasAborted.Set(); + receiveException = ex; + } + catch (Exception ex) + { + // Capture this exception so a test failure can give us more information. + receiveException = ex; + } + finally + { + Assert.IsType(receiveException); + } + })) + { + var clientBuffer = new byte[1024]; + + using (var client = new ClientWebSocket()) + { + await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None); + + var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000); + Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time."); + + await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default); + + var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000); + Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time."); + + var secondReceiveInitiatedDidNotTimeout = secondReceiveInitiated.Wait(10000); + Assert.True(secondReceiveInitiatedDidNotTimeout, "Second receive was not initiated within the allotted time."); + + serverSocket.Abort(); + + var socketWasAbortedDidNotTimeout = socketWasAborted.Wait(1000); // Give it a second to process the abort. + Assert.True(socketWasAbortedDidNotTimeout, "Abort did not occur within the allotted time."); + } + } + } + + [Fact] + public async Task WebSocket_AllowsCancelling_Pending_ReceiveAsync_When_CancellationTokenProvided() + { + WebSocket serverSocket = null; + CancellationTokenSource cts = new CancellationTokenSource(); + + var socketWasAccepted = new ManualResetEventSlim(); + var operationWasCancelled = new ManualResetEventSlim(); + var firstReceiveOccured = new ManualResetEventSlim(); + + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + serverSocket = await context.WebSockets.AcceptWebSocketAsync(); + socketWasAccepted.Set(); + + var serverBuffer = new byte[1024]; + + var finishedWithOperationCancelled = false; + + try + { + while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent) + { + var response = await serverSocket.ReceiveAsync(serverBuffer, cts.Token); + firstReceiveOccured.Set(); + } + } + catch (OperationCanceledException) + { + operationWasCancelled.Set(); + finishedWithOperationCancelled = true; + } + finally + { + Assert.True(finishedWithOperationCancelled); + } + })) + { + var clientBuffer = new byte[1024]; + + using (var client = new ClientWebSocket()) + { + await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None); + + var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000); + Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time."); + + await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default); + + var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000); + Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time."); + + cts.Cancel(); + + var operationWasCancelledDidNotTimeout = operationWasCancelled.Wait(1000); // Give it a second to process the abort. + Assert.True(operationWasCancelledDidNotTimeout, "Cancel did not occur within the allotted time."); + } + } + } + [Theory] [InlineData(HttpStatusCode.OK, null)] [InlineData(HttpStatusCode.Forbidden, "")]