From 2f2557eaf8ba1bd073a0783a6f3644b5a90de815 Mon Sep 17 00:00:00 2001 From: Amy Yan Date: Mon, 23 Oct 2023 11:25:24 +1100 Subject: [PATCH 1/3] feat: headers are now passed to `TLSVerifyCallback` --- src/WebSocketConnection.ts | 2 +- src/WebSocketServer.ts | 2 +- src/types.ts | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/WebSocketConnection.ts b/src/WebSocketConnection.ts index fe49427c..290cebf6 100644 --- a/src/WebSocketConnection.ts +++ b/src/WebSocketConnection.ts @@ -725,7 +725,7 @@ class WebSocketConnection { const ca = utils.collectPEMs(this.config.ca).map(utils.pemToDER); try { if (this.config.verifyPeer && this.config.verifyCallback != null) { - await this.config.verifyCallback?.(peerCertChain, ca); + await this.config.verifyCallback?.(peerCertChain, ca, request.headers); } this._localHost = request.connection.localAddress as Host; this._localPort = request.connection.localPort as Port; diff --git a/src/WebSocketServer.ts b/src/WebSocketServer.ts index 973ea8c7..6eba2291 100644 --- a/src/WebSocketServer.ts +++ b/src/WebSocketServer.ts @@ -379,7 +379,7 @@ class WebSocketServer { const peerCertChain = utils.toPeerCertChain(peerCert); const ca = utils.collectPEMs(this.config.ca).map(utils.pemToDER); try { - await this.config.verifyCallback(peerCertChain, ca); + await this.config.verifyCallback(peerCertChain, ca, info.req.headers); return done(true); } catch (e) { info.req.destroy(e); diff --git a/src/types.ts b/src/types.ts index 5d6a7833..8ba931ec 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,3 +1,5 @@ +import type { IncomingHttpHeaders } from "http"; + // Async /** @@ -79,6 +81,7 @@ type ConnectionMetadata = { type TLSVerifyCallback = ( certs: Array, ca: Array, + headers: IncomingHttpHeaders ) => PromiseLike; type WebSocketConfig = { From 6daf94665fe30e2d18c31aafa3167e1fd4bab9ca Mon Sep 17 00:00:00 2001 From: Amy Yan Date: Thu, 23 Nov 2023 15:40:02 +1100 Subject: [PATCH 2/3] feat: added ability to insert headers into `config` paramater of `WebSocketClient` and `WebSocketServer` --- src/WebSocketClient.ts | 1 + src/WebSocketConnection.ts | 6 +- src/WebSocketServer.ts | 30 +++++- src/types.ts | 10 +- tests/WebSocketClient.test.ts | 189 ++++++++++++++++++++++++++++++++++ 5 files changed, 232 insertions(+), 4 deletions(-) diff --git a/src/WebSocketClient.ts b/src/WebSocketClient.ts index 464aa6ee..8978b056 100644 --- a/src/WebSocketClient.ts +++ b/src/WebSocketClient.ts @@ -114,6 +114,7 @@ class WebSocketClient { key: wsConfig.key as any, cert: wsConfig.cert as any, ca: wsConfig.ca as any, + headers: wsConfig.headers, }); const connectionId = 0; diff --git a/src/WebSocketConnection.ts b/src/WebSocketConnection.ts index 290cebf6..35e81d56 100644 --- a/src/WebSocketConnection.ts +++ b/src/WebSocketConnection.ts @@ -725,7 +725,11 @@ class WebSocketConnection { const ca = utils.collectPEMs(this.config.ca).map(utils.pemToDER); try { if (this.config.verifyPeer && this.config.verifyCallback != null) { - await this.config.verifyCallback?.(peerCertChain, ca, request.headers); + await this.config.verifyCallback?.( + peerCertChain, + ca, + request.headers, + ); } this._localHost = request.connection.localAddress as Host; this._localPort = request.connection.localPort as Port; diff --git a/src/WebSocketServer.ts b/src/WebSocketServer.ts index 6eba2291..e9ff411b 100644 --- a/src/WebSocketServer.ts +++ b/src/WebSocketServer.ts @@ -243,6 +243,28 @@ class WebSocketServer { ); }; + protected handleServerHeaders = (headers: Array) => { + if (this.config.headers == null) { + return; + } + const configHeaders = { ...this.config.headers }; + for (let i = 0; i < headers.length; i++) { + const headerKV = headers[i].split(': ', 2); + if (headerKV.length !== 2) { + continue; + } + const [headerName] = headerKV; + const lowercaseHeaderName = headerName.toLowerCase(); + if (lowercaseHeaderName in configHeaders) { + headers[i] = `${headerName}: ${configHeaders[lowercaseHeaderName]}`; + delete configHeaders[lowercaseHeaderName]; + } + } + for (const [header, value] of Object.entries(configHeaders)) { + headers.push(`${header}: ${value}`); + } + }; + /** * WebSocketServer.constructor * @@ -379,7 +401,11 @@ class WebSocketServer { const peerCertChain = utils.toPeerCertChain(peerCert); const ca = utils.collectPEMs(this.config.ca).map(utils.pemToDER); try { - await this.config.verifyCallback(peerCertChain, ca, info.req.headers); + await this.config.verifyCallback( + peerCertChain, + ca, + info.req.headers, + ); return done(true); } catch (e) { info.req.destroy(e); @@ -391,6 +417,7 @@ class WebSocketServer { }); this.webSocketServer.on('connection', this.handleServerConnection); + this.webSocketServer.on('headers', this.handleServerHeaders); this.webSocketServer.on('close', this.handleWebSocketServerClosed); this.server.on('close', this.handleServerClosed); this.webSocketServer.on('error', this.handleServerError); @@ -485,6 +512,7 @@ class WebSocketServer { ); this.webSocketServer.off('connection', this.handleServerConnection); + this.webSocketServer.off('headers', this.handleServerHeaders); this.webSocketServer.off('close', this.handleServerClosed); this.server.off('close', this.handleServerClosed); this.webSocketServer.off('error', this.handleServerError); diff --git a/src/types.ts b/src/types.ts index 8ba931ec..d5cbd3fc 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,4 +1,4 @@ -import type { IncomingHttpHeaders } from "http"; +import type { IncomingHttpHeaders } from 'http'; // Async @@ -81,7 +81,7 @@ type ConnectionMetadata = { type TLSVerifyCallback = ( certs: Array, ca: Array, - headers: IncomingHttpHeaders + headers: IncomingHttpHeaders, ) => PromiseLike; type WebSocketConfig = { @@ -119,6 +119,12 @@ type WebSocketConfig = { */ cert?: string | Array | Uint8Array | Array; + /** + * Headers that will be attached to the HTTP Upgrade Request/Response. + * This can be used for authentication purposes, or to provide metadata regarding the connection. + */ + headers?: IncomingHttpHeaders; + /** * Verify the other peer. * Clients by default set this to true. diff --git a/tests/WebSocketClient.test.ts b/tests/WebSocketClient.test.ts index 78878767..17498eb7 100644 --- a/tests/WebSocketClient.test.ts +++ b/tests/WebSocketClient.test.ts @@ -463,6 +463,195 @@ describe(WebSocketClient.name, () => { await server.stop(); }); }); + describe('custom TLS verification with headers', () => { + test('server succeeds custom verification', async () => { + const tlsConfigs = await testsUtils.generateConfig('RSA'); + const authorization = 'password'; + const server = new WebSocketServer({ + logger: logger.getChild(WebSocketServer.name), + config: { + key: tlsConfigs.key, + cert: tlsConfigs.cert, + verifyPeer: false, + headers: { + authorization, + }, + }, + }); + const handleConnectionEventProm = promise(); + server.addEventListener( + events.EventWebSocketServerConnection.name, + handleConnectionEventProm.resolveP, + ); + await server.start({ + host: localhost, + }); + // Connection should succeed + const verifyProm = promise(); + const client = await WebSocketClient.createWebSocketClient({ + host: localhost, + port: server.port, + logger: logger.getChild(WebSocketClient.name), + config: { + verifyPeer: true, + verifyCallback: async (_certs, _ca, headers) => { + verifyProm.resolveP(headers.authorization); + }, + }, + }); + await handleConnectionEventProm.p; + await expect(verifyProm.p).resolves.toBe(authorization); + await client.destroy(); + await server.stop(); + }); + test('server fails custom verification', async () => { + const tlsConfigs = await testsUtils.generateConfig('RSA'); + const authorization = 'password'; + const server = new WebSocketServer({ + logger: logger.getChild(WebSocketServer.name), + config: { + key: tlsConfigs.key, + cert: tlsConfigs.cert, + verifyPeer: false, + headers: { + authorization, + }, + }, + }); + const handleConnectionEventProm = promise(); + server.addEventListener( + events.EventWebSocketServerConnection.name, + (event: events.EventWebSocketServerConnection) => + handleConnectionEventProm.resolveP(event.detail), + ); + await server.start({ + host: localhost, + }); + // Connection should fail + const clientProm = WebSocketClient.createWebSocketClient({ + host: localhost, + port: server.port, + logger: logger.getChild(WebSocketClient.name), + config: { + verifyPeer: true, + verifyCallback: () => { + throw Error('SOME ERROR'); + }, + }, + }); + clientProm.catch(() => {}); + + // Verification by peer happens after connection is securely established and started + const serverConn = await handleConnectionEventProm.p; + const serverErrorProm = promise(); + serverConn.addEventListener( + events.EventWebSocketConnectionError.name, + (evt: events.EventWebSocketConnectionError) => + serverErrorProm.rejectP(evt.detail), + ); + await expect(serverErrorProm.p).rejects.toThrow( + errors.ErrorWebSocketConnectionPeer, + ); + await expect(clientProm).rejects.toThrow( + errors.ErrorWebSocketConnectionLocal, + ); + + await server.stop(); + }); + test('client succeeds custom verification', async () => { + const tlsConfigs = await testsUtils.generateConfig('RSA'); + const authorization = 'password'; + const verifyProm = promise(); + const server = new WebSocketServer({ + logger: logger.getChild(WebSocketServer.name), + config: { + key: tlsConfigs.key, + cert: tlsConfigs.cert, + verifyPeer: true, + verifyCallback: async (_certs, _ca, headers) => { + verifyProm.resolveP(headers.authorization); + }, + }, + }); + const handleConnectionEventProm = promise(); + server.addEventListener( + events.EventWebSocketServerConnection.name, + handleConnectionEventProm.resolveP, + ); + await server.start({ + host: localhost, + }); + // Connection should succeed + const client = await WebSocketClient.createWebSocketClient({ + host: localhost, + port: server.port, + logger: logger.getChild(WebSocketClient.name), + config: { + verifyPeer: false, + key: tlsConfigs.key, + cert: tlsConfigs.cert, + headers: { + authorization, + }, + }, + }); + await handleConnectionEventProm.p; + await expect(verifyProm.p).resolves.toBe(authorization); + await client.destroy(); + await server.stop(); + }); + test('client fails custom verification', async () => { + const tlsConfigs = await testsUtils.generateConfig('RSA'); + const authorization = 'password'; + const server = new WebSocketServer({ + logger: logger.getChild(WebSocketServer.name), + config: { + key: tlsConfigs.key, + cert: tlsConfigs.cert, + verifyPeer: true, + verifyCallback: () => { + throw Error('SOME ERROR'); + }, + }, + }); + const handleConnectionEventProm = promise(); + server.addEventListener( + events.EventWebSocketServerConnection.name, + (event: events.EventWebSocketServerConnection) => + handleConnectionEventProm.resolveP(event.detail), + ); + await server.start({ + host: localhost, + }); + // Connection should fail + await expect( + WebSocketClient.createWebSocketClient({ + host: localhost, + port: server.port, + logger: logger.getChild(WebSocketClient.name), + config: { + key: tlsConfigs.key, + cert: tlsConfigs.cert, + verifyPeer: false, + headers: { + authorization, + }, + }, + }), + ).rejects.toHaveProperty('name', 'ErrorWebSocketConnectionPeer'); + + // // Server connection is never emitted + await Promise.race([ + handleConnectionEventProm.p.then(() => { + throw Error('Server connection should not be emitted'); + }), + // Allow some time + testsUtils.sleep(200), + ]); + + await server.stop(); + }); + }); describe.each(types)('custom TLS verification with %s', (type) => { test('server succeeds custom verification', async () => { const tlsConfigs = await testsUtils.generateConfig(type); From 1bf085312cc2df69ac25ee3e10b47bdbea06a0c0 Mon Sep 17 00:00:00 2001 From: Amy Yan Date: Thu, 23 Nov 2023 15:44:41 +1100 Subject: [PATCH 3/3] fix: `config.header` key casing is now disregarded --- src/WebSocketServer.ts | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/WebSocketServer.ts b/src/WebSocketServer.ts index e9ff411b..b7e97c56 100644 --- a/src/WebSocketServer.ts +++ b/src/WebSocketServer.ts @@ -312,6 +312,14 @@ class WebSocketServer { ...serverDefault, ...config, }; + // Config header names need to be set to lowercase + if (this.config.headers != null) { + const originalHeaders = this.config.headers; + this.config.headers = {}; + for (const [headerName, value] of Object.entries(originalHeaders)) { + this.config.headers[headerName.toLowerCase()] = value; + } + } this.resolveHostname = resolveHostname; this.connectTimeoutTime = connectTimeoutTime;