diff --git a/src/Middleware/WebSockets/src/ServerWebSocket.cs b/src/Middleware/WebSockets/src/ServerWebSocket.cs
new file mode 100644
index 000000000000..1fe6351237d2
--- /dev/null
+++ b/src/Middleware/WebSockets/src/ServerWebSocket.cs
@@ -0,0 +1,82 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Net.WebSockets;
+using Microsoft.AspNetCore.Http;
+
+namespace Microsoft.AspNetCore.WebSockets
+{
+ ///
+ /// Used in ASP.NET Core to wrap a WebSocket with its associated HttpContext so that when the WebSocket is aborted
+ /// the underlying HttpContext is aborted. All other methods are delegated to the underlying WebSocket.
+ ///
+ internal sealed class ServerWebSocket : WebSocket
+ {
+ private readonly WebSocket _wrappedSocket;
+ private readonly HttpContext _context;
+
+ internal ServerWebSocket(WebSocket wrappedSocket, HttpContext context)
+ {
+ ArgumentNullException.ThrowIfNull(wrappedSocket);
+ ArgumentNullException.ThrowIfNull(context);
+
+ _wrappedSocket = wrappedSocket;
+ _context = context;
+ }
+
+ public override WebSocketCloseStatus? CloseStatus => _wrappedSocket.CloseStatus;
+
+ public override string? CloseStatusDescription => _wrappedSocket.CloseStatusDescription;
+
+ public override WebSocketState State => _wrappedSocket.State;
+
+ public override string? SubProtocol => _wrappedSocket.SubProtocol;
+
+ public override void Abort()
+ {
+ _wrappedSocket.Abort();
+ _context.Abort();
+ }
+
+ public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
+ {
+ return _wrappedSocket.CloseAsync(closeStatus, statusDescription, cancellationToken);
+ }
+
+ public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
+ {
+ return _wrappedSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
+ }
+
+ public override void Dispose()
+ {
+ _wrappedSocket.Dispose();
+ }
+
+ public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken)
+ {
+ return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
+ }
+
+ public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken)
+ {
+ return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
+ }
+
+ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
+ {
+ return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
+ }
+
+ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
+ {
+ return _wrappedSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken);
+ }
+
+ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
+ {
+ return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
+ }
+ }
+}
+
diff --git a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs
index 6904450e6578..ff6a589bff18 100644
--- a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs
+++ b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs
@@ -194,13 +194,15 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext)
Stream opaqueTransport = await _upgradeFeature.UpgradeAsync(); // Sets status code to 101
- return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
+ var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
{
IsServer = true,
KeepAliveInterval = keepAliveInterval,
SubProtocol = subProtocol,
DangerousDeflateOptions = deflateOptions
});
+
+ return new ServerWebSocket(wrappedSocket, _context);
}
public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
index aa67be817719..97395bdd819f 100644
--- a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
+++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
@@ -6,6 +6,7 @@
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
+using Microsoft.AspNetCore.Connections;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Testing;
@@ -499,6 +500,146 @@ public async Task CloseFromCloseReceived_Success()
}
}
+ [Fact]
+ public async Task WebSocket_Abort_Interrupts_Pending_ReceiveAsync()
+ {
+ WebSocket serverSocket = null;
+
+ // Events that we want to sequence execution across client and server.
+ var socketWasAccepted = new ManualResetEventSlim();
+ var socketWasAborted = new ManualResetEventSlim();
+ var firstReceiveOccured = new ManualResetEventSlim();
+ var secondReceiveInitiated = new ManualResetEventSlim();
+
+ Exception receiveException = null;
+
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ {
+ Assert.True(context.WebSockets.IsWebSocketRequest);
+ serverSocket = await context.WebSockets.AcceptWebSocketAsync();
+ socketWasAccepted.Set();
+
+ var serverBuffer = new byte[1024];
+
+ try
+ {
+ while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent)
+ {
+ if (firstReceiveOccured.IsSet)
+ {
+ var pendingResponse = serverSocket.ReceiveAsync(serverBuffer, default);
+ secondReceiveInitiated.Set();
+ var response = await pendingResponse;
+ }
+ else
+ {
+ var response = await serverSocket.ReceiveAsync(serverBuffer, default);
+ firstReceiveOccured.Set();
+ }
+ }
+ }
+ catch (ConnectionAbortedException ex)
+ {
+ socketWasAborted.Set();
+ receiveException = ex;
+ }
+ catch (Exception ex)
+ {
+ // Capture this exception so a test failure can give us more information.
+ receiveException = ex;
+ }
+ finally
+ {
+ Assert.IsType(receiveException);
+ }
+ }))
+ {
+ var clientBuffer = new byte[1024];
+
+ using (var client = new ClientWebSocket())
+ {
+ await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
+
+ var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000);
+ Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time.");
+
+ await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default);
+
+ var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000);
+ Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time.");
+
+ var secondReceiveInitiatedDidNotTimeout = secondReceiveInitiated.Wait(10000);
+ Assert.True(secondReceiveInitiatedDidNotTimeout, "Second receive was not initiated within the allotted time.");
+
+ serverSocket.Abort();
+
+ var socketWasAbortedDidNotTimeout = socketWasAborted.Wait(1000); // Give it a second to process the abort.
+ Assert.True(socketWasAbortedDidNotTimeout, "Abort did not occur within the allotted time.");
+ }
+ }
+ }
+
+ [Fact]
+ public async Task WebSocket_AllowsCancelling_Pending_ReceiveAsync_When_CancellationTokenProvided()
+ {
+ WebSocket serverSocket = null;
+ CancellationTokenSource cts = new CancellationTokenSource();
+
+ var socketWasAccepted = new ManualResetEventSlim();
+ var operationWasCancelled = new ManualResetEventSlim();
+ var firstReceiveOccured = new ManualResetEventSlim();
+
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ {
+ Assert.True(context.WebSockets.IsWebSocketRequest);
+ serverSocket = await context.WebSockets.AcceptWebSocketAsync();
+ socketWasAccepted.Set();
+
+ var serverBuffer = new byte[1024];
+
+ var finishedWithOperationCancelled = false;
+
+ try
+ {
+ while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent)
+ {
+ var response = await serverSocket.ReceiveAsync(serverBuffer, cts.Token);
+ firstReceiveOccured.Set();
+ }
+ }
+ catch (OperationCanceledException)
+ {
+ operationWasCancelled.Set();
+ finishedWithOperationCancelled = true;
+ }
+ finally
+ {
+ Assert.True(finishedWithOperationCancelled);
+ }
+ }))
+ {
+ var clientBuffer = new byte[1024];
+
+ using (var client = new ClientWebSocket())
+ {
+ await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
+
+ var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000);
+ Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time.");
+
+ await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default);
+
+ var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000);
+ Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time.");
+
+ cts.Cancel();
+
+ var operationWasCancelledDidNotTimeout = operationWasCancelled.Wait(1000); // Give it a second to process the abort.
+ Assert.True(operationWasCancelledDidNotTimeout, "Cancel did not occur within the allotted time.");
+ }
+ }
+ }
+
[Theory]
[InlineData(HttpStatusCode.OK, null)]
[InlineData(HttpStatusCode.Forbidden, "")]