Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/WebSocketClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class WebSocketClient {
key: wsConfig.key as any,
cert: wsConfig.cert as any,
ca: wsConfig.ca as any,
headers: wsConfig.headers,
});

const connectionId = 0;
Expand Down
6 changes: 5 additions & 1 deletion src/WebSocketConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,11 @@ class WebSocketConnection {
const ca = utils.collectPEMs(this.config.ca).map(utils.pemToDER);
try {
if (this.config.verifyPeer && this.config.verifyCallback != null) {
await this.config.verifyCallback?.(peerCertChain, ca);
await this.config.verifyCallback?.(
peerCertChain,
ca,
request.headers,
);
}
this._localHost = request.connection.localAddress as Host;
this._localPort = request.connection.localPort as Port;
Expand Down
38 changes: 37 additions & 1 deletion src/WebSocketServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,28 @@ class WebSocketServer {
);
};

protected handleServerHeaders = (headers: Array<string>) => {
if (this.config.headers == null) {
return;
}
const configHeaders = { ...this.config.headers };
for (let i = 0; i < headers.length; i++) {
const headerKV = headers[i].split(': ', 2);
if (headerKV.length !== 2) {
continue;
}
const [headerName] = headerKV;
const lowercaseHeaderName = headerName.toLowerCase();
if (lowercaseHeaderName in configHeaders) {
headers[i] = `${headerName}: ${configHeaders[lowercaseHeaderName]}`;
delete configHeaders[lowercaseHeaderName];
}
}
for (const [header, value] of Object.entries(configHeaders)) {
headers.push(`${header}: ${value}`);
}
};

/**
* WebSocketServer.constructor
*
Expand Down Expand Up @@ -290,6 +312,14 @@ class WebSocketServer {
...serverDefault,
...config,
};
// Config header names need to be set to lowercase
if (this.config.headers != null) {
const originalHeaders = this.config.headers;
this.config.headers = {};
for (const [headerName, value] of Object.entries(originalHeaders)) {
this.config.headers[headerName.toLowerCase()] = value;
}
}
this.resolveHostname = resolveHostname;

this.connectTimeoutTime = connectTimeoutTime;
Expand Down Expand Up @@ -379,7 +409,11 @@ class WebSocketServer {
const peerCertChain = utils.toPeerCertChain(peerCert);
const ca = utils.collectPEMs(this.config.ca).map(utils.pemToDER);
try {
await this.config.verifyCallback(peerCertChain, ca);
await this.config.verifyCallback(
peerCertChain,
ca,
info.req.headers,
);
return done(true);
} catch (e) {
info.req.destroy(e);
Expand All @@ -391,6 +425,7 @@ class WebSocketServer {
});

this.webSocketServer.on('connection', this.handleServerConnection);
this.webSocketServer.on('headers', this.handleServerHeaders);
this.webSocketServer.on('close', this.handleWebSocketServerClosed);
this.server.on('close', this.handleServerClosed);
this.webSocketServer.on('error', this.handleServerError);
Expand Down Expand Up @@ -485,6 +520,7 @@ class WebSocketServer {
);

this.webSocketServer.off('connection', this.handleServerConnection);
this.webSocketServer.off('headers', this.handleServerHeaders);
this.webSocketServer.off('close', this.handleServerClosed);
this.server.off('close', this.handleServerClosed);
this.webSocketServer.off('error', this.handleServerError);
Expand Down
9 changes: 9 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import type { IncomingHttpHeaders } from 'http';

// Async

/**
Expand Down Expand Up @@ -79,6 +81,7 @@ type ConnectionMetadata = {
type TLSVerifyCallback = (
certs: Array<Uint8Array>,
ca: Array<Uint8Array>,
headers: IncomingHttpHeaders,
) => PromiseLike<void>;

type WebSocketConfig = {
Expand Down Expand Up @@ -116,6 +119,12 @@ type WebSocketConfig = {
*/
cert?: string | Array<string> | Uint8Array | Array<Uint8Array>;

/**
* Headers that will be attached to the HTTP Upgrade Request/Response.
* This can be used for authentication purposes, or to provide metadata regarding the connection.
*/
headers?: IncomingHttpHeaders;

/**
* Verify the other peer.
* Clients by default set this to true.
Expand Down
189 changes: 189 additions & 0 deletions tests/WebSocketClient.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,195 @@ describe(WebSocketClient.name, () => {
await server.stop();
});
});
describe('custom TLS verification with headers', () => {
test('server succeeds custom verification', async () => {
const tlsConfigs = await testsUtils.generateConfig('RSA');
const authorization = 'password';
const server = new WebSocketServer({
logger: logger.getChild(WebSocketServer.name),
config: {
key: tlsConfigs.key,
cert: tlsConfigs.cert,
verifyPeer: false,
headers: {
authorization,
},
},
});
const handleConnectionEventProm = promise<any>();
server.addEventListener(
events.EventWebSocketServerConnection.name,
handleConnectionEventProm.resolveP,
);
await server.start({
host: localhost,
});
// Connection should succeed
const verifyProm = promise<string | undefined>();
const client = await WebSocketClient.createWebSocketClient({
host: localhost,
port: server.port,
logger: logger.getChild(WebSocketClient.name),
config: {
verifyPeer: true,
verifyCallback: async (_certs, _ca, headers) => {
verifyProm.resolveP(headers.authorization);
},
},
});
await handleConnectionEventProm.p;
await expect(verifyProm.p).resolves.toBe(authorization);
await client.destroy();
await server.stop();
});
test('server fails custom verification', async () => {
const tlsConfigs = await testsUtils.generateConfig('RSA');
const authorization = 'password';
const server = new WebSocketServer({
logger: logger.getChild(WebSocketServer.name),
config: {
key: tlsConfigs.key,
cert: tlsConfigs.cert,
verifyPeer: false,
headers: {
authorization,
},
},
});
const handleConnectionEventProm = promise<WebSocketConnection>();
server.addEventListener(
events.EventWebSocketServerConnection.name,
(event: events.EventWebSocketServerConnection) =>
handleConnectionEventProm.resolveP(event.detail),
);
await server.start({
host: localhost,
});
// Connection should fail
const clientProm = WebSocketClient.createWebSocketClient({
host: localhost,
port: server.port,
logger: logger.getChild(WebSocketClient.name),
config: {
verifyPeer: true,
verifyCallback: () => {
throw Error('SOME ERROR');
},
},
});
clientProm.catch(() => {});

// Verification by peer happens after connection is securely established and started
const serverConn = await handleConnectionEventProm.p;
const serverErrorProm = promise<never>();
serverConn.addEventListener(
events.EventWebSocketConnectionError.name,
(evt: events.EventWebSocketConnectionError) =>
serverErrorProm.rejectP(evt.detail),
);
await expect(serverErrorProm.p).rejects.toThrow(
errors.ErrorWebSocketConnectionPeer,
);
await expect(clientProm).rejects.toThrow(
errors.ErrorWebSocketConnectionLocal,
);

await server.stop();
});
test('client succeeds custom verification', async () => {
const tlsConfigs = await testsUtils.generateConfig('RSA');
const authorization = 'password';
const verifyProm = promise<string | undefined>();
const server = new WebSocketServer({
logger: logger.getChild(WebSocketServer.name),
config: {
key: tlsConfigs.key,
cert: tlsConfigs.cert,
verifyPeer: true,
verifyCallback: async (_certs, _ca, headers) => {
verifyProm.resolveP(headers.authorization);
},
},
});
const handleConnectionEventProm = promise<any>();
server.addEventListener(
events.EventWebSocketServerConnection.name,
handleConnectionEventProm.resolveP,
);
await server.start({
host: localhost,
});
// Connection should succeed
const client = await WebSocketClient.createWebSocketClient({
host: localhost,
port: server.port,
logger: logger.getChild(WebSocketClient.name),
config: {
verifyPeer: false,
key: tlsConfigs.key,
cert: tlsConfigs.cert,
headers: {
authorization,
},
},
});
await handleConnectionEventProm.p;
await expect(verifyProm.p).resolves.toBe(authorization);
await client.destroy();
await server.stop();
});
test('client fails custom verification', async () => {
const tlsConfigs = await testsUtils.generateConfig('RSA');
const authorization = 'password';
const server = new WebSocketServer({
logger: logger.getChild(WebSocketServer.name),
config: {
key: tlsConfigs.key,
cert: tlsConfigs.cert,
verifyPeer: true,
verifyCallback: () => {
throw Error('SOME ERROR');
},
},
});
const handleConnectionEventProm = promise<WebSocketConnection>();
server.addEventListener(
events.EventWebSocketServerConnection.name,
(event: events.EventWebSocketServerConnection) =>
handleConnectionEventProm.resolveP(event.detail),
);
await server.start({
host: localhost,
});
// Connection should fail
await expect(
WebSocketClient.createWebSocketClient({
host: localhost,
port: server.port,
logger: logger.getChild(WebSocketClient.name),
config: {
key: tlsConfigs.key,
cert: tlsConfigs.cert,
verifyPeer: false,
headers: {
authorization,
},
},
}),
).rejects.toHaveProperty('name', 'ErrorWebSocketConnectionPeer');

// // Server connection is never emitted
await Promise.race([
handleConnectionEventProm.p.then(() => {
throw Error('Server connection should not be emitted');
}),
// Allow some time
testsUtils.sleep(200),
]);

await server.stop();
});
});
describe.each(types)('custom TLS verification with %s', (type) => {
test('server succeeds custom verification', async () => {
const tlsConfigs = await testsUtils.generateConfig(type);
Expand Down