diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs index 4d71e31863ce03..7159283212d601 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs @@ -82,17 +82,25 @@ public Task ConnectAsync(Uri uri, CancellationToken cancellationToken) return ConnectAsyncCore(uri, cancellationToken); } - private Task ConnectAsyncCore(Uri uri, CancellationToken cancellationToken) + private async Task ConnectAsyncCore(Uri uri, CancellationToken cancellationToken) { _innerWebSocket = new WebSocketHandle(); - // Change internal state to 'connected' to enable the other methods - if ((InternalState)Interlocked.CompareExchange(ref _state, (int)InternalState.Connected, (int)InternalState.Connecting) != InternalState.Connecting) + try + { + await _innerWebSocket.ConnectAsync(uri, cancellationToken, Options).ConfigureAwait(false); + } + catch { - return Task.FromException(new ObjectDisposedException(nameof(ClientWebSocket))); // Aborted/Disposed during connect. + Dispose(); + throw; } - return _innerWebSocket.ConnectAsync(uri, cancellationToken, Options); + if ((InternalState)Interlocked.CompareExchange(ref _state, (int)InternalState.Connected, (int)InternalState.Connecting) != InternalState.Connecting) + { + Debug.Assert(_state == (int)InternalState.Disposed); + throw new ObjectDisposedException(GetType().FullName); + } } public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs index 68faae7d7408cf..d6f9512b15ce4c 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs @@ -33,6 +33,12 @@ public async Task ConnectAsync_NotWebSocketServer_ThrowsWebSocketExceptionWithMe } Assert.Equal(WebSocketState.Closed, cws.State); Assert.Equal(exceptionMessage, ex.Message); + + // Other operations throw after failed connect + await Assert.ThrowsAsync(() => cws.ReceiveAsync(new byte[1], default)); + await Assert.ThrowsAsync(() => cws.SendAsync(new byte[1], WebSocketMessageType.Binary, true, default)); + await Assert.ThrowsAsync(() => cws.CloseAsync(WebSocketCloseStatus.NormalClosure, null, default)); + await Assert.ThrowsAsync(() => cws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, default)); } }