Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/Middleware/WebSockets/src/ServerWebSocket.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Net.WebSockets;
using Microsoft.AspNetCore.Http;

namespace Microsoft.AspNetCore.WebSockets;

/// <summary>
/// Used in ASP.NET Core to wrap a WebSocket with its associated HttpContext so that when the WebSocket is aborted
/// the underlying HttpContext is aborted. All other methods are delegated to the underlying WebSocket.
/// </summary>
internal sealed class ServerWebSocket : WebSocket
{
private readonly WebSocket _wrappedSocket;
private readonly HttpContext _context;

internal ServerWebSocket(WebSocket wrappedSocket, HttpContext context)
{
ArgumentNullException.ThrowIfNull(wrappedSocket);
ArgumentNullException.ThrowIfNull(context);

_wrappedSocket = wrappedSocket;
_context = context;
}

public override WebSocketCloseStatus? CloseStatus => _wrappedSocket.CloseStatus;

public override string? CloseStatusDescription => _wrappedSocket.CloseStatusDescription;

public override WebSocketState State => _wrappedSocket.State;

public override string? SubProtocol => _wrappedSocket.SubProtocol;

public override void Abort()
{
_wrappedSocket.Abort();
_context.Abort();
}

public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
{
return _wrappedSocket.CloseAsync(closeStatus, statusDescription, cancellationToken);
}

public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
{
return _wrappedSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
}

public override void Dispose()
{
_wrappedSocket.Dispose();
}

public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
{
return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
}

public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, CancellationToken cancellationToken)
{
return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
}

public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
}

public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
{
return _wrappedSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken);
}

public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
}
}
4 changes: 3 additions & 1 deletion src/Middleware/WebSockets/src/WebSocketMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,15 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
opaqueTransport = await _upgradeFeature!.UpgradeAsync(); // Sets status code to 101
}

return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
{
IsServer = true,
KeepAliveInterval = keepAliveInterval,
SubProtocol = subProtocol,
DangerousDeflateOptions = deflateOptions
});

return new ServerWebSocket(wrappedSocket, _context);
}

public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
Expand Down
141 changes: 141 additions & 0 deletions src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Testing;
using Microsoft.Net.Http.Headers;

Expand Down Expand Up @@ -495,6 +496,146 @@ public async Task CloseFromCloseReceived_Success()
}
}

[Fact]
public async Task WebSocket_Abort_Interrupts_Pending_ReceiveAsync()
{
WebSocket serverSocket = null;

// Events that we want to sequence execution across client and server.
var socketWasAccepted = new ManualResetEventSlim();
var socketWasAborted = new ManualResetEventSlim();
var firstReceiveOccured = new ManualResetEventSlim();
var secondReceiveInitiated = new ManualResetEventSlim();

Exception receiveException = null;

await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
serverSocket = await context.WebSockets.AcceptWebSocketAsync();
socketWasAccepted.Set();

var serverBuffer = new byte[1024];

try
{
while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent)
{
if (firstReceiveOccured.IsSet)
{
var pendingResponse = serverSocket.ReceiveAsync(serverBuffer, default);
secondReceiveInitiated.Set();
var response = await pendingResponse;
}
else
{
var response = await serverSocket.ReceiveAsync(serverBuffer, default);
firstReceiveOccured.Set();
}
}
}
catch (ConnectionAbortedException ex)
{
socketWasAborted.Set();
receiveException = ex;
}
catch (Exception ex)
{
// Capture this exception so a test failure can give us more information.
receiveException = ex;
}
finally
{
Assert.IsType<ConnectionAbortedException>(receiveException);
}
}))
{
var clientBuffer = new byte[1024];

using (var client = new ClientWebSocket())
{
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);

var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000);
Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time.");

await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default);

var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000);
Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time.");

var secondReceiveInitiatedDidNotTimeout = secondReceiveInitiated.Wait(10000);
Assert.True(secondReceiveInitiatedDidNotTimeout, "Second receive was not initiated within the allotted time.");

serverSocket.Abort();

var socketWasAbortedDidNotTimeout = socketWasAborted.Wait(1000); // Give it a second to process the abort.
Assert.True(socketWasAbortedDidNotTimeout, "Abort did not occur within the allotted time.");
}
}
}

[Fact]
public async Task WebSocket_AllowsCancelling_Pending_ReceiveAsync_When_CancellationTokenProvided()
{
WebSocket serverSocket = null;
CancellationTokenSource cts = new CancellationTokenSource();

var socketWasAccepted = new ManualResetEventSlim();
var operationWasCancelled = new ManualResetEventSlim();
var firstReceiveOccured = new ManualResetEventSlim();

await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
serverSocket = await context.WebSockets.AcceptWebSocketAsync();
socketWasAccepted.Set();

var serverBuffer = new byte[1024];

var finishedWithOperationCancelled = false;

try
{
while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent)
{
var response = await serverSocket.ReceiveAsync(serverBuffer, cts.Token);
firstReceiveOccured.Set();
}
}
catch (OperationCanceledException)
{
operationWasCancelled.Set();
finishedWithOperationCancelled = true;
}
finally
{
Assert.True(finishedWithOperationCancelled);
}
}))
{
var clientBuffer = new byte[1024];

using (var client = new ClientWebSocket())
{
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);

var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000);
Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time.");

await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default);

var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000);
Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time.");

cts.Cancel();

var operationWasCancelledDidNotTimeout = operationWasCancelled.Wait(1000); // Give it a second to process the abort.
Assert.True(operationWasCancelledDidNotTimeout, "Cancel did not occur within the allotted time.");
}
}
}

[Theory]
[InlineData(HttpStatusCode.OK, null)]
[InlineData(HttpStatusCode.Forbidden, "")]
Expand Down