diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/BrowserWebSocket.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/BrowserWebSocket.cs index 4288444f2de37b..6df638d5b429d7 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/BrowserWebSocket.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/BrowserWebSocket.cs @@ -34,6 +34,7 @@ internal sealed class BrowserWebSocket : WebSocket }); private TaskCompletionSource? _tcsClose; + private TaskCompletionSource? _tcsConnect; private WebSocketCloseStatus? _innerWebSocketCloseStatus; private string? _innerWebSocketCloseStatusDescription; @@ -47,6 +48,7 @@ internal sealed class BrowserWebSocket : WebSocket private MemoryStream? _writeBuffer; private ReceivePayload? _bufferedPayload; private readonly CancellationTokenSource _cts; + private int _closeStatus; // variable to track the close status after a close is sent. // Stages of this class. private int _state; @@ -56,9 +58,12 @@ private enum InternalState Created = 0, Connecting = 1, Connected = 2, - Disposed = 3 + CloseSent = 3, + Disposed = 4, + Aborted = 5, } + private bool _disposed; /// /// Initializes a new instance of the class. @@ -78,7 +83,7 @@ public override WebSocketState State { get { - if (_innerWebSocket != null && !_innerWebSocket.IsDisposed) + if (_innerWebSocket != null && !_innerWebSocket.IsDisposed && _state != (int)InternalState.Aborted) { return ReadyStateToDotNetState((int)_innerWebSocket.GetObjectProperty("readyState")); } @@ -86,6 +91,9 @@ public override WebSocketState State { InternalState.Created => WebSocketState.None, InternalState.Connecting => WebSocketState.Connecting, + InternalState.Aborted => WebSocketState.Aborted, + InternalState.Disposed => WebSocketState.Closed, + InternalState.CloseSent => WebSocketState.CloseSent, _ => WebSocketState.Closed }; } @@ -112,19 +120,24 @@ private static WebSocketState ReadyStateToDotNetState(int readyState) => internal async Task ConnectAsyncJavaScript(Uri uri, CancellationToken cancellationToken, List? requestedSubProtocols) { - // Check that we have not started already - int priorState = Interlocked.CompareExchange(ref _state, (int)InternalState.Connecting, (int)InternalState.Created); - if (priorState == (int)InternalState.Disposed) - { - throw new ObjectDisposedException(GetType().FullName); - } - else if (priorState != (int)InternalState.Created) + // Check that we have not started already. + int prevState = _state; + _state = _state == (int)InternalState.Created ? (int)InternalState.Connecting : _state; + + switch ((InternalState)prevState) { - throw new InvalidOperationException(SR.net_WebSockets_AlreadyStarted); + case InternalState.Disposed: + throw new ObjectDisposedException(GetType().FullName); + + case InternalState.Created: + break; + + default: + throw new InvalidOperationException(SR.net_WebSockets_AlreadyStarted); } CancellationTokenRegistration connectRegistration = cancellationToken.Register(cts => ((CancellationTokenSource)cts!).Cancel(), _cts); - TaskCompletionSource tcsConnect = new TaskCompletionSource(); + _tcsConnect = new TaskCompletionSource(); // For Abort/Dispose. Calling Abort on the request at any point will close the connection. _cts.Token.Register(s => ((BrowserWebSocket)s!).AbortRequest(), this); @@ -163,20 +176,21 @@ internal async Task ConnectAsyncJavaScript(Uri uri, CancellationToken cancellati _innerWebSocketCloseStatusDescription = closeEvt.GetObjectProperty("reason")?.ToString(); _receiveMessageQueue.Writer.TryWrite(new ReceivePayload(Array.Empty(), WebSocketMessageType.Close)); NativeCleanup(); - if ((InternalState)_state == InternalState.Connecting) + if ((InternalState)_state == InternalState.Connecting || (InternalState)_state == InternalState.Aborted) { + _state = (int)InternalState.Disposed; if (cancellationToken.IsCancellationRequested) { - tcsConnect.TrySetCanceled(cancellationToken); + _tcsConnect.TrySetCanceled(cancellationToken); } else { - tcsConnect.TrySetException(new WebSocketException(WebSocketError.NativeError)); + _tcsConnect.TrySetException(new WebSocketException(WebSocketError.NativeError)); } } else { - _tcsClose?.SetResult(); + _tcsClose?.TrySetResult(); } } }; @@ -192,19 +206,21 @@ internal async Task ConnectAsyncJavaScript(Uri uri, CancellationToken cancellati if (!cancellationToken.IsCancellationRequested) { // Change internal _state to 'Connected' to enable the other methods - if (Interlocked.CompareExchange(ref _state, (int)InternalState.Connected, (int)InternalState.Connecting) != (int)InternalState.Connecting) + int prevState = _state; + _state = _state == (int)InternalState.Connecting ? (int)InternalState.Connected : _state; + if (prevState != (int)InternalState.Connecting) { // Aborted/Disposed during connect. - tcsConnect.TrySetException(new ObjectDisposedException(GetType().FullName)); + _tcsConnect.TrySetException(new ObjectDisposedException(GetType().FullName)); } else { - tcsConnect.SetResult(); + _tcsConnect.SetResult(); } } else { - tcsConnect.SetCanceled(cancellationToken); + _tcsConnect.SetCanceled(cancellationToken); } } }; @@ -217,7 +233,7 @@ internal async Task ConnectAsyncJavaScript(Uri uri, CancellationToken cancellati // Attach the onMessage callaback _innerWebSocket.SetObjectProperty("onmessage", _onMessage); - await tcsConnect.Task.ConfigureAwait(continueOnCapturedContext: true); + await _tcsConnect.Task.ConfigureAwait(continueOnCapturedContext: true); } catch (Exception wse) { @@ -227,7 +243,7 @@ internal async Task ConnectAsyncJavaScript(Uri uri, CancellationToken cancellati case OperationCanceledException: throw; default: - throw new WebSocketException(SR.net_webstatus_ConnectFailure, wse); + throw new WebSocketException(WebSocketError.Faulted, SR.net_webstatus_ConnectFailure, wse); } } finally @@ -318,32 +334,51 @@ private void NativeCleanup() public override void Dispose() { - int priorState = Interlocked.Exchange(ref _state, (int)InternalState.Disposed); - if (priorState == (int)InternalState.Disposed) + if (!_disposed) { - // No cleanup required. - return; - } + if (_state < (int)InternalState.Aborted) { + _state = (int)InternalState.Disposed; + } + _disposed = true; - // registered by the CancellationTokenSource cts in the connect method - _cts.Cancel(false); - _cts.Dispose(); + if (!_cts.IsCancellationRequested) + { + // registered by the CancellationTokenSource cts in the connect method + _cts.Cancel(false); + _cts.Dispose(); + } - _writeBuffer?.Dispose(); - _receiveMessageQueue.Writer.Complete(); + _writeBuffer?.Dispose(); + _receiveMessageQueue.Writer.TryComplete(); - NativeCleanup(); + NativeCleanup(); - _innerWebSocket?.Dispose(); + _innerWebSocket?.Dispose(); + } } // This method is registered by the CancellationTokenSource cts in the connect method // and called by Dispose or Abort so that any open websocket connection can be closed. private async void AbortRequest() { - if (State == WebSocketState.Open || State == WebSocketState.Connecting) + switch (State) { - await CloseAsyncCore(WebSocketCloseStatus.NormalClosure, SR.net_WebSockets_Connection_Aborted, CancellationToken.None).ConfigureAwait(continueOnCapturedContext: true); + case WebSocketState.Open: + case WebSocketState.Connecting: + { + await CloseAsyncCore(WebSocketCloseStatus.NormalClosure, SR.net_WebSockets_Connection_Aborted, CancellationToken.None).ConfigureAwait(continueOnCapturedContext: true); + // The following code is for those browsers that do not set Close and send an onClose event in certain instances i.e. firefox and safari. + // chrome will send an onClose event and we tear down the websocket there. + if (ReadyStateToDotNetState(_closeStatus) == WebSocketState.CloseSent) + { + _writeBuffer?.Dispose(); + _receiveMessageQueue.Writer.TryWrite(new ReceivePayload(Array.Empty(), WebSocketMessageType.Close)); + _receiveMessageQueue.Writer.TryComplete(); + NativeCleanup(); + _tcsConnect?.TrySetCanceled(); + } + } + break; } } @@ -423,6 +458,20 @@ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType m return Task.CompletedTask; } + // This method is registered by the CancellationTokenSource in the receive async method + private async void CancelRequest() + { + int prevState = _state; + _state = (int)InternalState.Aborted; + _receiveMessageQueue.Writer.TryComplete(); + if (prevState == (int)InternalState.Connected || prevState == (int)InternalState.Connecting) + { + if (prevState == (int)InternalState.Connecting) + _state = (int)InternalState.CloseSent; + await CloseAsyncCore(WebSocketCloseStatus.NormalClosure, SR.net_WebSockets_Connection_Aborted, CancellationToken.None).ConfigureAwait(continueOnCapturedContext: true); + } + } + /// /// Receives data on as an asynchronous operation. /// @@ -431,22 +480,43 @@ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType m /// Cancellation token. public override async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) { - WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); - ThrowIfDisposed(); - ThrowOnInvalidState(State, WebSocketState.Open, WebSocketState.CloseSent); - _bufferedPayload ??= await _receiveMessageQueue.Reader.ReadAsync(cancellationToken).ConfigureAwait(continueOnCapturedContext: true); + if (cancellationToken.IsCancellationRequested) + { + return await Task.FromException(new OperationCanceledException()).ConfigureAwait(continueOnCapturedContext: true); + } + + CancellationTokenSource _receiveCTS = new CancellationTokenSource(); + CancellationTokenRegistration receiveRegistration = cancellationToken.Register(cts => ((CancellationTokenSource)cts!).Cancel(), _receiveCTS); + _receiveCTS.Token.Register(s => ((BrowserWebSocket)s!).CancelRequest(), this); try { - bool endOfMessage = _bufferedPayload.BufferPayload(buffer, out WebSocketReceiveResult receiveResult); + WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); + + ThrowIfDisposed(); + ThrowOnInvalidState(State, WebSocketState.Open, WebSocketState.CloseSent); + _bufferedPayload ??= await _receiveMessageQueue.Reader.ReadAsync(cancellationToken).ConfigureAwait(continueOnCapturedContext: true); + bool endOfMessage = _bufferedPayload!.BufferPayload(buffer, out WebSocketReceiveResult receiveResult); if (endOfMessage) _bufferedPayload = null; return receiveResult; } catch (Exception exc) { - throw new WebSocketException(WebSocketError.NativeError, exc); + switch (exc) + { + case OperationCanceledException: + return await Task.FromException(exc).ConfigureAwait(continueOnCapturedContext: true); + case ChannelClosedException: + return await Task.FromException(new WebSocketException(WebSocketError.InvalidState, SR.Format(SR.net_WebSockets_InvalidState, State, "Open, CloseSent"))).ConfigureAwait(continueOnCapturedContext: true); + default: + return await Task.FromException(new WebSocketException(WebSocketError.InvalidState, SR.Format(SR.net_WebSockets_InvalidState, State, "Open, CloseSent"))).ConfigureAwait(continueOnCapturedContext: true); + } + } + finally + { + receiveRegistration.Unregister(); } } @@ -455,12 +525,20 @@ public override async Task ReceiveAsync(ArraySegment public override void Abort() { - if (_state == (int)InternalState.Disposed) + if (_state != (int)InternalState.Disposed) { - return; + int prevState = _state; + if (prevState != (int)InternalState.Connecting) + { + _state = (int)InternalState.Aborted; + } + + if (prevState < (int)InternalState.Aborted) + { + _cts.Cancel(true); + _tcsClose?.TrySetResult(); + } } - _state = (int)WebSocketState.Aborted; - Dispose(); } public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) @@ -478,7 +556,6 @@ public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? status { return Task.FromException(exc); } - return CloseAsyncCore(closeStatus, statusDescription, cancellationToken); } @@ -490,6 +567,7 @@ private Task CloseAsyncCore(WebSocketCloseStatus closeStatus, string? statusDesc _innerWebSocketCloseStatus = closeStatus; _innerWebSocketCloseStatusDescription = statusDescription; _innerWebSocket!.Invoke("close", (int)closeStatus, statusDescription); + _closeStatus = (int)_innerWebSocket.GetObjectProperty("readyState"); return _tcsClose.Task; } catch (Exception exc) diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs index cf91e6821db05e..1dde3894c8dc29 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs @@ -24,6 +24,7 @@ public void Dispose() public void Abort() { + _abortSource.Cancel(); WebSocket?.Abort(); } @@ -67,7 +68,7 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli case OperationCanceledException _ when cancellationToken.IsCancellationRequested: throw; default: - throw new WebSocketException(SR.net_webstatus_ConnectFailure, exc); + throw new WebSocketException(WebSocketError.Faulted, SR.net_webstatus_ConnectFailure, exc); } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.cs index 6b48fd285258af..565150cbe91a09 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.cs @@ -17,7 +17,6 @@ public AbortTest(ITestOutputHelper output) : base(output) { } [OuterLoop] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45674", TestPlatforms.Browser)] public async Task Abort_ConnectAndAbort_ThrowsWebSocketExceptionWithmessage(Uri server) { using (var cws = new ClientWebSocket()) @@ -43,7 +42,6 @@ public async Task Abort_ConnectAndAbort_ThrowsWebSocketExceptionWithmessage(Uri [OuterLoop] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45674", TestPlatforms.Browser)] public async Task Abort_SendAndAbort_Success(Uri server) { await TestCancellation(async (cws) => @@ -64,7 +62,6 @@ await TestCancellation(async (cws) => [OuterLoop] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45674", TestPlatforms.Browser)] public async Task Abort_ReceiveAndAbort_Success(Uri server) { await TestCancellation(async (cws) => @@ -89,7 +86,6 @@ await cws.SendAsync( [OuterLoop] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45674", TestPlatforms.Browser)] public async Task Abort_CloseAndAbort_Success(Uri server) { await TestCancellation(async (cws) => diff --git a/src/libraries/System.Net.WebSockets.Client/tests/CancelTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/CancelTest.cs index 32780301a3644b..92c44a712378a1 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/CancelTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/CancelTest.cs @@ -131,7 +131,6 @@ public async Task ReceiveAsync_CancelThenReceive_ThrowsOperationCanceledExceptio [OuterLoop("Uses external servers")] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45674", TestPlatforms.Browser)] public async Task ReceiveAsync_ReceiveThenCancel_ThrowsOperationCanceledException(Uri server) { using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(server, TimeOutMilliseconds, _output)) @@ -148,7 +147,6 @@ public async Task ReceiveAsync_ReceiveThenCancel_ThrowsOperationCanceledExceptio [OuterLoop("Uses external servers")] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45674", TestPlatforms.Browser)] public async Task ReceiveAsync_AfterCancellationDoReceiveAsync_ThrowsWebSocketException(Uri server) { using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(server, TimeOutMilliseconds, _output)) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs index ca24740d69cd1e..f521a8a4769544 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs @@ -26,6 +26,7 @@ public class ClientWebSocketTestBase }).ToArray(); public const int TimeOutMilliseconds = 20000; + public const int BrowserTimeOutMilliseconds = 30000; public const int CloseDescriptionMaxLength = 123; public readonly ITestOutputHelper _output; diff --git a/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.cs index 7999e4d636c295..49a4e3cb84fb7a 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.cs @@ -88,7 +88,7 @@ public async Task SendReceive_PartialMessageDueToSmallReceiveBuffer_Success(Uri [OuterLoop("Uses external servers")] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45586", TestPlatforms.Browser)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/46983", TestPlatforms.Browser)] // JS Websocket does not support see issue public async Task SendReceive_PartialMessageBeforeCompleteMessageArrives_Success(Uri server) { var rand = new Random(); @@ -131,7 +131,6 @@ public async Task SendReceive_PartialMessageBeforeCompleteMessageArrives_Success [OuterLoop("Uses external servers")] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45586", TestPlatforms.Browser)] public async Task SendAsync_SendCloseMessageType_ThrowsArgumentExceptionWithMessage(Uri server) { using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(server, TimeOutMilliseconds, _output)) @@ -219,12 +218,13 @@ public async Task SendAsync_MultipleOutstandingSendOperations_Throws(Uri server) [OuterLoop("Uses external servers")] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45586", TestPlatforms.Browser)] public async Task ReceiveAsync_MultipleOutstandingReceiveOperations_Throws(Uri server) { using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(server, TimeOutMilliseconds, _output)) { - var cts = new CancellationTokenSource(TimeOutMilliseconds); + // It seems that sometimes the default timeout is not enough for browser so we will extend it + // See issue https://github.com/dotnet/runtime/issues/46909 + var cts = PlatformDetection.IsBrowser ? new CancellationTokenSource(BrowserTimeOutMilliseconds) : new CancellationTokenSource(TimeOutMilliseconds); Task[] tasks = new Task[2]; @@ -284,7 +284,6 @@ await SendAsync( [OuterLoop("Uses external servers")] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45586", TestPlatforms.Browser)] public async Task SendAsync_SendZeroLengthPayloadAsEndOfMessage_Success(Uri server) { using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(server, TimeOutMilliseconds, _output)) @@ -329,7 +328,10 @@ public async Task SendReceive_VaryingLengthBuffers_Success(Uri server) using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(server, TimeOutMilliseconds, _output)) { var rand = new Random(); - var ctsDefault = new CancellationTokenSource(TimeOutMilliseconds); + + // It seems that sometimes the default timeout is not enough for browser so we will extend it + // See issue https://github.com/dotnet/runtime/issues/46909 + var ctsDefault = PlatformDetection.IsBrowser ? new CancellationTokenSource(BrowserTimeOutMilliseconds) : new CancellationTokenSource(TimeOutMilliseconds); // Values chosen close to boundaries in websockets message length handling as well // as in vectors used in mask application. @@ -463,7 +465,6 @@ await LoopbackServer.CreateServerAsync(async (server, url) => [OuterLoop("Uses external servers")] [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/45586", TestPlatforms.Browser)] public async Task ZeroByteReceive_CompletesWhenDataAvailable(Uri server) { using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(server, TimeOutMilliseconds, _output))