From 15eed2e833846f16e286def1f7fc8691515aa4b9 Mon Sep 17 00:00:00 2001 From: Katya Sokolova Date: Thu, 8 Dec 2022 20:19:56 +0100 Subject: [PATCH 1/5] Fix compression --- .../System/Net/WebSockets/ClientWebSocket.cs | 3 ++ .../tests/DeflateTests.cs | 47 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs index de3cde635d9ad3..5780555470c1e9 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs @@ -148,6 +148,9 @@ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType m public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => ConnectedWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) => + ConnectedWebSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken); + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) => ConnectedWebSocket.ReceiveAsync(buffer, cancellationToken); diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index 92f287fe4506ff..32d3c863465542 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -86,6 +86,53 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => }), new LoopbackServer.Options { WebSocketEndpoint = true }); } + [ConditionalFact(nameof(WebSocketsSupported))] + public async Task ThrowWhenContinuationWithDifferentCompressionFlags() + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using (var cws = new ClientWebSocket()) + using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) + { + cws.Options.DangerousDeflateOptions = new WebSocketDeflateOptions(); + await ConnectAsync(cws, uri, cts.Token); + + await cws.SendAsync(Memory.Empty, WebSocketMessageType.Text, WebSocketMessageFlags.DisableCompression, default); + Assert.Throws("messageFlags", () => + cws.SendAsync(Memory.Empty, WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, default)); + } + }, server => server.AcceptConnectionAsync(async connection => + { + Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection); + }), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + + [ConditionalFact(nameof(WebSocketsSupported))] + public async Task SendHelloWithDisableCompression() + { + byte[] bytes = "Hello"u8.ToArray(); + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using (var cws = new ClientWebSocket()) + using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) + { + cws.Options.DangerousDeflateOptions = new WebSocketDeflateOptions(); + await ConnectAsync(cws, uri, cts.Token); + + WebSocketMessageFlags flags = WebSocketMessageFlags.DisableCompression | WebSocketMessageFlags.EndOfMessage; + await cws.SendAsync(bytes, WebSocketMessageType.Text, flags, cts.Token); + } + }, server => server.AcceptConnectionAsync(async connection => + { + var buffer = new byte[bytes.Length]; + Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection); + using WebSocket websocket = WebSocket.CreateFromStream(connection.Stream, true, null, TimeSpan.FromSeconds(30)); + Assert.True(websocket.State == WebSocketState.Open || websocket.State == WebSocketState.CloseSent); + await websocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + Assert.Equal(bytes, buffer); + }), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + private static string CreateDeflateOptionsHeader(WebSocketDeflateOptions options) { var builder = new StringBuilder(); From f40af7fc1645f40a9295af6510672c3ba3030071 Mon Sep 17 00:00:00 2001 From: Katya Sokolova Date: Fri, 9 Dec 2022 13:21:09 +0100 Subject: [PATCH 2/5] Apply suggestions from code review Co-authored-by: Miha Zupan --- .../tests/DeflateTests.cs | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index 32d3c863465542..c22afa5d5f21a4 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -87,23 +87,22 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => } [ConditionalFact(nameof(WebSocketsSupported))] - public async Task ThrowWhenContinuationWithDifferentCompressionFlags() + public async Task ThrowsWhenContinuationHasDifferentCompressionFlags() { await LoopbackServer.CreateClientAndServerAsync(async uri => { - using (var cws = new ClientWebSocket()) - using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) - { - cws.Options.DangerousDeflateOptions = new WebSocketDeflateOptions(); - await ConnectAsync(cws, uri, cts.Token); + using var cws = new ClientWebSocket(); + using var cts = new CancellationTokenSource(TimeOutMilliseconds); + + cws.Options.DangerousDeflateOptions = new WebSocketDeflateOptions(); + await ConnectAsync(cws, uri, cts.Token); - await cws.SendAsync(Memory.Empty, WebSocketMessageType.Text, WebSocketMessageFlags.DisableCompression, default); - Assert.Throws("messageFlags", () => - cws.SendAsync(Memory.Empty, WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, default)); - } + await cws.SendAsync(Memory.Empty, WebSocketMessageType.Text, WebSocketMessageFlags.DisableCompression, default); + Assert.Throws("messageFlags", () => + cws.SendAsync(Memory.Empty, WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, default)); }, server => server.AcceptConnectionAsync(async connection => { - Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection); + await LoopbackHelper.WebSocketHandshakeAsync(connection); }), new LoopbackServer.Options { WebSocketEndpoint = true }); } @@ -113,19 +112,18 @@ public async Task SendHelloWithDisableCompression() byte[] bytes = "Hello"u8.ToArray(); await LoopbackServer.CreateClientAndServerAsync(async uri => { - using (var cws = new ClientWebSocket()) - using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) - { - cws.Options.DangerousDeflateOptions = new WebSocketDeflateOptions(); - await ConnectAsync(cws, uri, cts.Token); + using var cws = new ClientWebSocket(); + using var cts = new CancellationTokenSource(TimeOutMilliseconds); + + cws.Options.DangerousDeflateOptions = new WebSocketDeflateOptions(); + await ConnectAsync(cws, uri, cts.Token); - WebSocketMessageFlags flags = WebSocketMessageFlags.DisableCompression | WebSocketMessageFlags.EndOfMessage; - await cws.SendAsync(bytes, WebSocketMessageType.Text, flags, cts.Token); - } + WebSocketMessageFlags flags = WebSocketMessageFlags.DisableCompression | WebSocketMessageFlags.EndOfMessage; + await cws.SendAsync(bytes, WebSocketMessageType.Text, flags, cts.Token); }, server => server.AcceptConnectionAsync(async connection => { var buffer = new byte[bytes.Length]; - Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection); + await LoopbackHelper.WebSocketHandshakeAsync(connection); using WebSocket websocket = WebSocket.CreateFromStream(connection.Stream, true, null, TimeSpan.FromSeconds(30)); Assert.True(websocket.State == WebSocketState.Open || websocket.State == WebSocketState.CloseSent); await websocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); From 95e8abdf0b0d7d81fba8e418c7ff9d5422212864 Mon Sep 17 00:00:00 2001 From: Katya Sokolova Date: Fri, 9 Dec 2022 13:24:30 +0100 Subject: [PATCH 3/5] Adding SendAsync to ref --- .../ref/System.Net.WebSockets.Client.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index d7b7a97d724950..edb4eb043bcb0b 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -26,6 +26,7 @@ public override void Dispose() { } public override System.Threading.Tasks.ValueTask ReceiveAsync(System.Memory buffer, System.Threading.CancellationToken cancellationToken) { throw null; } public override System.Threading.Tasks.Task SendAsync(System.ArraySegment buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken) { throw null; } public override System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken) { throw null; } + public override System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, System.Net.WebSockets.WebSocketMessageFlags messageFlags, System.Threading.CancellationToken cancellationToken) { throw null; } } public sealed partial class ClientWebSocketOptions { From c06ae140b783697cf3cf152e88b71407e6c2fcf0 Mon Sep 17 00:00:00 2001 From: Katya Sokolova Date: Fri, 9 Dec 2022 15:02:07 +0100 Subject: [PATCH 4/5] fix ws deflate tests --- .../tests/DeflateTests.cs | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index c22afa5d5f21a4..dce6d6da0d3c56 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -89,20 +89,29 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => [ConditionalFact(nameof(WebSocketsSupported))] public async Task ThrowsWhenContinuationHasDifferentCompressionFlags() { + var deflateOpt = new WebSocketDeflateOptions + { + ClientMaxWindowBits = 14, + ClientContextTakeover = true, + ServerMaxWindowBits = 14, + ServerContextTakeover = true + }; await LoopbackServer.CreateClientAndServerAsync(async uri => { using var cws = new ClientWebSocket(); using var cts = new CancellationTokenSource(TimeOutMilliseconds); - cws.Options.DangerousDeflateOptions = new WebSocketDeflateOptions(); + cws.Options.DangerousDeflateOptions = deflateOpt; await ConnectAsync(cws, uri, cts.Token); + await cws.SendAsync(Memory.Empty, WebSocketMessageType.Text, WebSocketMessageFlags.DisableCompression, default); Assert.Throws("messageFlags", () => cws.SendAsync(Memory.Empty, WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, default)); }, server => server.AcceptConnectionAsync(async connection => { - await LoopbackHelper.WebSocketHandshakeAsync(connection); + var extensionsReply = CreateDeflateOptionsHeader(deflateOpt); + await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply); }), new LoopbackServer.Options { WebSocketEndpoint = true }); } @@ -110,22 +119,43 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => public async Task SendHelloWithDisableCompression() { byte[] bytes = "Hello"u8.ToArray(); + byte[] compressed = new byte[] { 0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00 }; + var deflateOpt = new WebSocketDeflateOptions + { + ClientMaxWindowBits = 14, + ClientContextTakeover = true, + ServerMaxWindowBits = 14, + ServerContextTakeover = true + }; + await LoopbackServer.CreateClientAndServerAsync(async uri => { using var cws = new ClientWebSocket(); using var cts = new CancellationTokenSource(TimeOutMilliseconds); - cws.Options.DangerousDeflateOptions = new WebSocketDeflateOptions(); + cws.Options.DangerousDeflateOptions = deflateOpt; await ConnectAsync(cws, uri, cts.Token); + await cws.SendAsync(bytes, WebSocketMessageType.Text, true, cts.Token); + WebSocketMessageFlags flags = WebSocketMessageFlags.DisableCompression | WebSocketMessageFlags.EndOfMessage; await cws.SendAsync(bytes, WebSocketMessageType.Text, flags, cts.Token); }, server => server.AcceptConnectionAsync(async connection => { var buffer = new byte[bytes.Length]; - await LoopbackHelper.WebSocketHandshakeAsync(connection); - using WebSocket websocket = WebSocket.CreateFromStream(connection.Stream, true, null, TimeSpan.FromSeconds(30)); + var extensionsReply = CreateDeflateOptionsHeader(deflateOpt); + await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply); + using WebSocket websocket = WebSocket.CreateFromStream(connection.Stream, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = deflateOpt + }); + Assert.True(websocket.State == WebSocketState.Open || websocket.State == WebSocketState.CloseSent); + + await websocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + Assert.Equal(bytes, buffer); + await websocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); Assert.Equal(bytes, buffer); }), new LoopbackServer.Options { WebSocketEndpoint = true }); From af9943eaff941591c34774523fd44319c4b65c09 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Mon, 12 Dec 2022 13:26:10 +0000 Subject: [PATCH 5/5] Check bytes on server side --- .../tests/DeflateTests.cs | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index dce6d6da0d3c56..1fd9b21d838141 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -119,7 +119,13 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => public async Task SendHelloWithDisableCompression() { byte[] bytes = "Hello"u8.ToArray(); - byte[] compressed = new byte[] { 0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00 }; + + int prefixLength = 2; + byte[] rawPrefix = new byte[] { 0x81, 0x85 }; // fin=1, rsv=0, opcode=text; mask=1, len=5 + int rawRemainingBytes = 9; // mask bytes (4) + payload bytes (5) + byte[] compressedPrefix = new byte[] { 0xc1, 0x87 }; // fin=1, rsv=compressed, opcode=text; mask=1, len=7 + int compressedRemainingBytes = 11; // mask bytes (4) + payload bytes (7) + var deflateOpt = new WebSocketDeflateOptions { ClientMaxWindowBits = 14, @@ -142,22 +148,26 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => await cws.SendAsync(bytes, WebSocketMessageType.Text, flags, cts.Token); }, server => server.AcceptConnectionAsync(async connection => { - var buffer = new byte[bytes.Length]; + var buffer = new byte[compressedRemainingBytes]; var extensionsReply = CreateDeflateOptionsHeader(deflateOpt); await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply); - using WebSocket websocket = WebSocket.CreateFromStream(connection.Stream, new WebSocketCreationOptions - { - IsServer = true, - DangerousDeflateOptions = deflateOpt - }); - Assert.True(websocket.State == WebSocketState.Open || websocket.State == WebSocketState.CloseSent); + // first message is compressed + await ReadExactAsync(buffer, prefixLength); + Assert.Equal(compressedPrefix, buffer[..prefixLength]); + // read rest of the frame + await ReadExactAsync(buffer, compressedRemainingBytes); - await websocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); - Assert.Equal(bytes, buffer); + // second message is not compressed + await ReadExactAsync(buffer, prefixLength); + Assert.Equal(rawPrefix, buffer[..prefixLength]); + // read rest of the frame + await ReadExactAsync(buffer, rawRemainingBytes); - await websocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); - Assert.Equal(bytes, buffer); + async Task ReadExactAsync(byte[] buf, int n) + { + await connection.Stream.ReadAtLeastAsync(buf.AsMemory(0, n), n); + } }), new LoopbackServer.Options { WebSocketEndpoint = true }); }