From 7f8a1e42b3f0648eabea44ea8d1b1e5aa142e98a Mon Sep 17 00:00:00 2001 From: Amy Yan Date: Fri, 28 Jun 2024 14:52:53 +1000 Subject: [PATCH] feature: isomorphic `WebSocket` support --- src/WebSocketClient.ts | 30 ++-- src/WebSocketConnection.ts | 91 +++++++++--- src/utils.ts | 6 + tests/WebSocketClient.browser.test.ts | 192 ++++++++++++++++++++++++++ 4 files changed, 285 insertions(+), 34 deletions(-) create mode 100644 tests/WebSocketClient.browser.test.ts diff --git a/src/WebSocketClient.ts b/src/WebSocketClient.ts index 8978b056..155a927d 100644 --- a/src/WebSocketClient.ts +++ b/src/WebSocketClient.ts @@ -8,7 +8,7 @@ import type { ContextTimed, ContextTimedInput } from '@matrixai/contexts'; import { AbstractEvent } from '@matrixai/events'; import { createDestroy } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; -import WebSocket from 'ws'; +import * as ws from 'ws'; import { EventAll } from '@matrixai/events'; import { context, timedCancellable } from '@matrixai/contexts/dist/decorators'; import * as errors from './errors'; @@ -62,6 +62,7 @@ class WebSocketClient { reasonToCode?: StreamReasonToCode; codeToReason?: StreamCodeToReason; logger?: Logger; + _webSocketClass?: typeof globalThis.WebSocket | typeof ws.WebSocket; }, ctx?: Partial, ): Promise; @@ -79,6 +80,9 @@ class WebSocketClient { reasonToCode, codeToReason, logger = new Logger(`${this.name}`), + _webSocketClass = globalThis.WebSocket == null + ? ws.WebSocket + : globalThis.WebSocket, }: { host: string; port: number; @@ -87,6 +91,7 @@ class WebSocketClient { reasonToCode?: StreamReasonToCode; codeToReason?: StreamCodeToReason; logger?: Logger; + _webSocketClass?: typeof globalThis.WebSocket | typeof ws.WebSocket; }, @context ctx: ContextTimed, ): Promise { @@ -106,16 +111,19 @@ class WebSocketClient { const address = `wss://${utils.buildAddress(host_, port_)}`; - // RejectUnauthorized must be false when TLSVerifyCallback exists, - // This is so that verification can be deferred to the callback rather than the system installed Certs - const webSocket = new WebSocket(address, { - rejectUnauthorized: - wsConfig.verifyPeer && wsConfig.verifyCallback == null, - key: wsConfig.key as any, - cert: wsConfig.cert as any, - ca: wsConfig.ca as any, - headers: wsConfig.headers, - }); + let webSocket: ws.WebSocket | typeof globalThis.WebSocket.prototype; + if (_webSocketClass === ws.WebSocket) { + webSocket = new ws.WebSocket(address, { + rejectUnauthorized: + wsConfig.verifyPeer && wsConfig.verifyCallback == null, + key: wsConfig.key as any, + cert: wsConfig.cert as any, + ca: wsConfig.ca as any, + headers: wsConfig.headers, + }); + } else { + webSocket = new _webSocketClass(address); + } const connectionId = 0; const connection = new WebSocketConnection({ diff --git a/src/WebSocketConnection.ts b/src/WebSocketConnection.ts index 35e81d56..322acf45 100644 --- a/src/WebSocketConnection.ts +++ b/src/WebSocketConnection.ts @@ -82,7 +82,7 @@ class WebSocketConnection { * Internal native WebSocket object. * @internal */ - protected socket: ws.WebSocket; + protected socket: ws.WebSocket | typeof globalThis.WebSocket.prototype; protected config: WebSocketConfig; @@ -263,6 +263,12 @@ class WebSocketConnection { this.streamMap.delete(stream.streamId); }; + protected handleBrowserSocketMessage = async ( + event: MessageEvent, + ) => { + return this.handleSocketMessage(event.data, true); + }; + protected handleSocketMessage = async ( data: ws.RawData, isBinary: boolean, @@ -433,7 +439,7 @@ class WebSocketConnection { ); }; - protected handleSocketError = (err: Error) => { + protected handleSocketError = (err: any) => { const errorCode = utils.ConnectionErrorCode.InternalServerError; const reason = 'An error occurred on the underlying WebSocket instance'; this.closeSocket(errorCode, reason); @@ -518,7 +524,7 @@ class WebSocketConnection { connectionId: number; meta?: undefined; config: WebSocketConfig; - socket: ws.WebSocket; + socket: ws.WebSocket | typeof globalThis.WebSocket.prototype; reasonToCode?: StreamReasonToCode; codeToReason?: StreamCodeToReason; logger?: Logger; @@ -528,13 +534,14 @@ class WebSocketConnection { connectionId: number; meta: ConnectionMetadata; config: WebSocketConfig; - socket: ws.WebSocket; + socket: ws.WebSocket | typeof globalThis.WebSocket.prototype; reasonToCode?: StreamReasonToCode; codeToReason?: StreamCodeToReason; logger?: Logger; }) { this.logger = logger ?? new Logger(`${this.constructor.name}`); this.connectionId = connectionId; + socket.binaryType = 'arraybuffer'; this.socket = socket; this.config = config; this.type = type; @@ -706,15 +713,27 @@ class WebSocketConnection { }), ); }; - this.socket.once('error', openErrorHandler); const openHandler = () => { this.resolveSecureEstablishedP(); }; - this.socket.once('open', openHandler); - // This will always happen, no need to remove the handler - this.socket.once('close', this.handleSocketClose); + if (utils.isNodeWebsocket(this.socket)) { + this.socket.once('error', openErrorHandler); + this.socket.once('open', openHandler); + // This will always happen, no need to remove the handler + this.socket.once('close', this.handleSocketClose); + } else { + this.socket.addEventListener('error', openErrorHandler, { once: true }); + this.socket.addEventListener('open', openHandler, { once: true }); + // This will always happen, no need to remove the handler + this.socket.addEventListener( + 'close', + (event) => + this.handleSocketClose(event.code, Buffer.from(event.reason)), + { once: true }, + ); + } - if (this.type === 'client') { + if (this.type === 'client' && utils.isNodeWebsocket(this.socket)) { this.socket.once('upgrade', async (request) => { const tlsSocket = request.socket as TLSSocket; const peerCert = tlsSocket.getPeerCertificate(true); @@ -788,23 +807,39 @@ class WebSocketConnection { ); } - this.socket.off('open', openHandler); - // Upgrade only exists on the ws library, we can use removeAllListeners without worrying - this.socket.removeAllListeners('upgrade'); + if (utils.isNodeWebsocket(this.socket)) { + this.socket.off('open', openHandler); + // Upgrade only exists on the ws library, we can use removeAllListeners without worrying + this.socket.removeAllListeners('upgrade'); + } else { + this.socket.removeEventListener('open', openHandler); + } + // Close the ws if it's open at this stage await this.closedP; throw e; } finally { ctx.signal.removeEventListener('abort', abortHandler); - // Upgrade has already been removed by being called once or by the catch - this.socket.off('error', openErrorHandler); + if (utils.isNodeWebsocket(this.socket)) { + // Upgrade has already been removed by being called once or by the catch + this.socket.off('error', openErrorHandler); + } else { + this.socket.removeEventListener('error', openErrorHandler); + } } // Set the connection up - this.socket.on('message', this.handleSocketMessage); - this.socket.on('ping', this.handleSocketPing); - this.socket.on('pong', this.handleSocketPong); - this.socket.once('error', this.handleSocketError); + if (utils.isNodeWebsocket(this.socket)) { + this.socket.on('message', this.handleSocketMessage); + this.socket.on('ping', this.handleSocketPing); + this.socket.on('pong', this.handleSocketPong); + this.socket.once('error', this.handleSocketError); + } else { + this.socket.addEventListener('message', this.handleBrowserSocketMessage); + this.socket.addEventListener('error', this.handleSocketError, { + once: true, + }); + } if (this.config.keepAliveIntervalTime != null) { this.startKeepAliveIntervalTimer(this.config.keepAliveIntervalTime); @@ -996,10 +1031,18 @@ class WebSocketConnection { events.EventWebSocketConnectionClose.name, this.handleEventWebSocketConnectionClose, ); - this.socket.off('message', this.handleSocketMessage); - this.socket.off('ping', this.handleSocketPing); - this.socket.off('pong', this.handleSocketPong); - this.socket.off('error', this.handleSocketError); + if (utils.isNodeWebsocket(this.socket)) { + this.socket.off('message', this.handleSocketMessage); + this.socket.off('ping', this.handleSocketPing); + this.socket.off('pong', this.handleSocketPong); + this.socket.off('error', this.handleSocketError); + } else { + this.socket.removeEventListener( + 'message', + this.handleBrowserSocketMessage, + ); + this.socket.removeEventListener('error', this.handleSocketError); + } this.logger.info(`Stopped ${this.constructor.name}`); } @@ -1049,7 +1092,9 @@ class WebSocketConnection { protected startKeepAliveIntervalTimer(ms: number): void { const keepAliveHandler = async (signal: AbortSignal) => { if (signal.aborted) return; - this.socket.ping(); + if (utils.isNodeWebsocket(this.socket)) { + this.socket.ping(); + } this.keepAliveIntervalTimer = new Timer({ delay: ms, handler: keepAliveHandler, diff --git a/src/utils.ts b/src/utils.ts index f62f6f78..9924a952 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -2,6 +2,7 @@ import type { Callback, Host, Port, PromiseDeconstructed } from './types'; import type { DetailedPeerCertificate } from 'tls'; import * as dns from 'dns'; import { IPv4, IPv6, Validator } from 'ip-num'; +import * as ws from 'ws'; import * as errors from './errors'; const textEncoder = new TextEncoder(); @@ -11,6 +12,10 @@ function never(message?: string): never { throw new errors.ErrorWebSocketUndefinedBehaviour(message); } +function isNodeWebsocket(websocket: any): websocket is ws.WebSocket { + return websocket.constructor === ws.WebSocket; +} + /** * Is it an IPv4 address? */ @@ -467,6 +472,7 @@ export { textEncoder, textDecoder, never, + isNodeWebsocket, isIPv4, isIPv6, isIPv4MappedIPv6, diff --git a/tests/WebSocketClient.browser.test.ts b/tests/WebSocketClient.browser.test.ts new file mode 100644 index 00000000..336f1423 --- /dev/null +++ b/tests/WebSocketClient.browser.test.ts @@ -0,0 +1,192 @@ +import type { KeyTypes } from './utils'; +import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; +import * as ws from 'ws'; +import { promise } from '@/utils'; +import * as events from '@/events'; +import * as errors from '@/errors'; +import WebSocketClient from '@/WebSocketClient'; +import WebSocketServer from '@/WebSocketServer'; +import * as testsUtils from './utils'; + +describe(`${WebSocketClient.name} browser`, () => { + const logger = new Logger(`${WebSocketClient.name} Test`, LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + const localhost = '127.0.0.1'; + const types: Array = ['RSA', 'ECDSA', 'ED25519']; + // Const types: Array = ['RSA']; + const defaultType = types[0]; + + class BrowserWebSocket extends ws.WebSocket { + constructor(address: string, protocols: any) { + super(address, protocols, { + rejectUnauthorized: false, + }); + } + } + + test('to ipv6 server succeeds', async () => { + const connectionEventProm = + promise(); + const tlsConfigServer = await testsUtils.generateConfig(defaultType); + const server = new WebSocketServer({ + logger: logger.getChild(WebSocketServer.name), + config: { + key: tlsConfigServer.key, + cert: tlsConfigServer.cert, + verifyPeer: false, + }, + }); + server.addEventListener( + events.EventWebSocketServerConnection.name, + (e: events.EventWebSocketServerConnection) => + connectionEventProm.resolveP(e), + ); + await server.start({ + host: '::1', + port: 0, + }); + const client = await WebSocketClient.createWebSocketClient({ + host: '::1', + port: server.port, + logger: logger.getChild(WebSocketClient.name), + config: { + verifyPeer: false, + }, + _webSocketClass: BrowserWebSocket as any, + }); + const conn = (await connectionEventProm.p).detail; + expect(conn.localHost).toBe('::1'); + expect(conn.localPort).toBe(server.port); + expect(conn.remoteHost).toBe('::1'); + await client.destroy(); + await server.stop(); + }); + test('to dual stack server succeeds', async () => { + const connectionEventProm = + promise(); + const tlsConfigServer = await testsUtils.generateConfig(defaultType); + const server = new WebSocketServer({ + logger: logger.getChild(WebSocketServer.name), + config: { + key: tlsConfigServer.key, + cert: tlsConfigServer.cert, + verifyPeer: false, + }, + }); + server.addEventListener( + events.EventWebSocketServerConnection.name, + (e: events.EventWebSocketServerConnection) => + connectionEventProm.resolveP(e), + ); + await server.start({ + host: '::', + port: 0, + }); + const client = await WebSocketClient.createWebSocketClient({ + host: '::', // Will resolve to ::1 + port: server.port, + logger: logger.getChild(WebSocketClient.name), + config: { + verifyPeer: false, + }, + _webSocketClass: BrowserWebSocket as any, + }); + const conn = (await connectionEventProm.p).detail; + expect(conn.localHost).toBe('::1'); + expect(conn.localPort).toBe(server.port); + expect(conn.remoteHost).toBe('::1'); + await client.destroy(); + await server.stop(); + }); + describe('hard connection failures', () => { + test('internal error when there is no server', async () => { + // WebSocketClient repeatedly dials until the connection timeout + await expect( + WebSocketClient.createWebSocketClient({ + host: localhost, + port: 56666, + logger: logger.getChild(WebSocketClient.name), + config: { + keepAliveTimeoutTime: 200, + verifyPeer: false, + }, + _webSocketClass: BrowserWebSocket as any, + }), + ).rejects.toHaveProperty( + ['name'], + errors.ErrorWebSocketConnectionLocal.name, + ); + }); + test('client times out with ctx timer while starting', async () => { + const tlsConfigServer = await testsUtils.generateConfig(defaultType); + const server = new WebSocketServer({ + logger: logger.getChild(WebSocketServer.name), + config: { + key: tlsConfigServer.key, + cert: tlsConfigServer.cert, + verifyPeer: true, + verifyCallback: async () => { + await testsUtils.sleep(1000); + }, + }, + }); + await server.start({ + host: localhost, + port: 0, + }); + await expect( + WebSocketClient.createWebSocketClient( + { + host: localhost, + port: server.port, + logger: logger.getChild(WebSocketClient.name), + config: { + verifyPeer: false, + }, + _webSocketClass: BrowserWebSocket as any, + }, + { timer: 100 }, + ), + ).rejects.toThrow(errors.ErrorWebSocketClientCreateTimeOut); + await server.stop(); + }); + test('client times out with ctx signal while starting', async () => { + const abortController = new AbortController(); + const tlsConfigServer = await testsUtils.generateConfig(defaultType); + const server = new WebSocketServer({ + logger: logger.getChild(WebSocketServer.name), + config: { + key: tlsConfigServer.key, + cert: tlsConfigServer.cert, + verifyPeer: true, + verifyCallback: async () => { + await testsUtils.sleep(1000); + }, + }, + }); + await server.start({ + host: localhost, + port: 0, + }); + const clientProm = WebSocketClient.createWebSocketClient( + { + host: localhost, + port: server.port, + logger: logger.getChild(WebSocketClient.name), + config: { + verifyPeer: false, + }, + _webSocketClass: BrowserWebSocket as any, + }, + { signal: abortController.signal }, + ); + await testsUtils.sleep(100); + abortController.abort(Error('abort error')); + await expect(clientProm).rejects.toThrow(Error('abort error')); + await server.stop(); + }); + }); +});