Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions src/WebSocketClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import type { ContextTimed, ContextTimedInput } from '@matrixai/contexts';
import { AbstractEvent } from '@matrixai/events';
import { createDestroy } from '@matrixai/async-init';
import Logger from '@matrixai/logger';
import WebSocket from 'ws';
import * as ws from 'ws';
Comment thread
amydevs marked this conversation as resolved.
import { EventAll } from '@matrixai/events';
import { context, timedCancellable } from '@matrixai/contexts/dist/decorators';
import * as errors from './errors';
Expand Down Expand Up @@ -62,6 +62,7 @@ class WebSocketClient {
reasonToCode?: StreamReasonToCode;
codeToReason?: StreamCodeToReason;
logger?: Logger;
_webSocketClass?: typeof globalThis.WebSocket | typeof ws.WebSocket;
},
ctx?: Partial<ContextTimedInput>,
): Promise<WebSocketClient>;
Expand All @@ -79,6 +80,9 @@ class WebSocketClient {
reasonToCode,
codeToReason,
logger = new Logger(`${this.name}`),
_webSocketClass = globalThis.WebSocket == null
? ws.WebSocket
: globalThis.WebSocket,
}: {
host: string;
port: number;
Expand All @@ -87,6 +91,7 @@ class WebSocketClient {
reasonToCode?: StreamReasonToCode;
codeToReason?: StreamCodeToReason;
logger?: Logger;
_webSocketClass?: typeof globalThis.WebSocket | typeof ws.WebSocket;
},
@context ctx: ContextTimed,
): Promise<WebSocketClient> {
Expand All @@ -106,16 +111,19 @@ class WebSocketClient {

const address = `wss://${utils.buildAddress(host_, port_)}`;

// RejectUnauthorized must be false when TLSVerifyCallback exists,
// This is so that verification can be deferred to the callback rather than the system installed Certs
const webSocket = new WebSocket(address, {
rejectUnauthorized:
wsConfig.verifyPeer && wsConfig.verifyCallback == null,
key: wsConfig.key as any,
cert: wsConfig.cert as any,
ca: wsConfig.ca as any,
headers: wsConfig.headers,
});
let webSocket: ws.WebSocket | typeof globalThis.WebSocket.prototype;
if (_webSocketClass === ws.WebSocket) {
webSocket = new ws.WebSocket(address, {
rejectUnauthorized:
wsConfig.verifyPeer && wsConfig.verifyCallback == null,
key: wsConfig.key as any,
cert: wsConfig.cert as any,
ca: wsConfig.ca as any,
headers: wsConfig.headers,
});
} else {
webSocket = new _webSocketClass(address);
}

const connectionId = 0;
const connection = new WebSocketConnection({
Expand Down
91 changes: 68 additions & 23 deletions src/WebSocketConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class WebSocketConnection {
* Internal native WebSocket object.
* @internal
*/
protected socket: ws.WebSocket;
protected socket: ws.WebSocket | typeof globalThis.WebSocket.prototype;

protected config: WebSocketConfig;

Expand Down Expand Up @@ -263,6 +263,12 @@ class WebSocketConnection {
this.streamMap.delete(stream.streamId);
};

protected handleBrowserSocketMessage = async (
event: MessageEvent<ArrayBuffer>,
) => {
return this.handleSocketMessage(event.data, true);
};

protected handleSocketMessage = async (
data: ws.RawData,
isBinary: boolean,
Expand Down Expand Up @@ -433,7 +439,7 @@ class WebSocketConnection {
);
};

protected handleSocketError = (err: Error) => {
protected handleSocketError = (err: any) => {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the cause can either be an Error or an event. It doesn't rlly matter what this value is, as it's just set as the cause of the thrown error

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using any type tends to be pretty hacky. Generally if the type isn't well defined and it doesn't really matter then unknown can be used instead of any.

While any can be anything and assigned to anything. unknown could be anything but can't be assigned to anything else. So it's safer to use. https://stackoverflow.com/questions/51439843/unknown-vs-any

const errorCode = utils.ConnectionErrorCode.InternalServerError;
const reason = 'An error occurred on the underlying WebSocket instance';
this.closeSocket(errorCode, reason);
Expand Down Expand Up @@ -518,7 +524,7 @@ class WebSocketConnection {
connectionId: number;
meta?: undefined;
config: WebSocketConfig;
socket: ws.WebSocket;
socket: ws.WebSocket | typeof globalThis.WebSocket.prototype;
reasonToCode?: StreamReasonToCode;
codeToReason?: StreamCodeToReason;
logger?: Logger;
Expand All @@ -528,13 +534,14 @@ class WebSocketConnection {
connectionId: number;
meta: ConnectionMetadata;
config: WebSocketConfig;
socket: ws.WebSocket;
socket: ws.WebSocket | typeof globalThis.WebSocket.prototype;
reasonToCode?: StreamReasonToCode;
codeToReason?: StreamCodeToReason;
logger?: Logger;
}) {
this.logger = logger ?? new Logger(`${this.constructor.name}`);
this.connectionId = connectionId;
socket.binaryType = 'arraybuffer';
this.socket = socket;
this.config = config;
this.type = type;
Expand Down Expand Up @@ -706,15 +713,27 @@ class WebSocketConnection {
}),
);
};
this.socket.once('error', openErrorHandler);
const openHandler = () => {
this.resolveSecureEstablishedP();
};
this.socket.once('open', openHandler);
// This will always happen, no need to remove the handler
this.socket.once('close', this.handleSocketClose);
if (utils.isNodeWebsocket(this.socket)) {
this.socket.once('error', openErrorHandler);
this.socket.once('open', openHandler);
// This will always happen, no need to remove the handler
this.socket.once('close', this.handleSocketClose);
} else {
this.socket.addEventListener('error', openErrorHandler, { once: true });
this.socket.addEventListener('open', openHandler, { once: true });
// This will always happen, no need to remove the handler
this.socket.addEventListener(
'close',
(event) =>
this.handleSocketClose(event.code, Buffer.from(event.reason)),
{ once: true },
);
}

if (this.type === 'client') {
if (this.type === 'client' && utils.isNodeWebsocket(this.socket)) {
this.socket.once('upgrade', async (request) => {
const tlsSocket = request.socket as TLSSocket;
const peerCert = tlsSocket.getPeerCertificate(true);
Expand Down Expand Up @@ -788,23 +807,39 @@ class WebSocketConnection {
);
}

this.socket.off('open', openHandler);
// Upgrade only exists on the ws library, we can use removeAllListeners without worrying
this.socket.removeAllListeners('upgrade');
if (utils.isNodeWebsocket(this.socket)) {
this.socket.off('open', openHandler);
// Upgrade only exists on the ws library, we can use removeAllListeners without worrying
this.socket.removeAllListeners('upgrade');
} else {
this.socket.removeEventListener('open', openHandler);
}

// Close the ws if it's open at this stage
await this.closedP;
throw e;
} finally {
ctx.signal.removeEventListener('abort', abortHandler);
// Upgrade has already been removed by being called once or by the catch
this.socket.off('error', openErrorHandler);
if (utils.isNodeWebsocket(this.socket)) {
// Upgrade has already been removed by being called once or by the catch
this.socket.off('error', openErrorHandler);
} else {
this.socket.removeEventListener('error', openErrorHandler);
}
}

// Set the connection up
this.socket.on('message', this.handleSocketMessage);
this.socket.on('ping', this.handleSocketPing);
this.socket.on('pong', this.handleSocketPong);
this.socket.once('error', this.handleSocketError);
if (utils.isNodeWebsocket(this.socket)) {
this.socket.on('message', this.handleSocketMessage);
this.socket.on('ping', this.handleSocketPing);
this.socket.on('pong', this.handleSocketPong);
this.socket.once('error', this.handleSocketError);
} else {
this.socket.addEventListener('message', this.handleBrowserSocketMessage);
this.socket.addEventListener('error', this.handleSocketError, {
once: true,
});
}

if (this.config.keepAliveIntervalTime != null) {
this.startKeepAliveIntervalTimer(this.config.keepAliveIntervalTime);
Expand Down Expand Up @@ -996,10 +1031,18 @@ class WebSocketConnection {
events.EventWebSocketConnectionClose.name,
this.handleEventWebSocketConnectionClose,
);
this.socket.off('message', this.handleSocketMessage);
this.socket.off('ping', this.handleSocketPing);
this.socket.off('pong', this.handleSocketPong);
this.socket.off('error', this.handleSocketError);
if (utils.isNodeWebsocket(this.socket)) {
this.socket.off('message', this.handleSocketMessage);
this.socket.off('ping', this.handleSocketPing);
this.socket.off('pong', this.handleSocketPong);
this.socket.off('error', this.handleSocketError);
} else {
this.socket.removeEventListener(
'message',
this.handleBrowserSocketMessage,
);
this.socket.removeEventListener('error', this.handleSocketError);
}

this.logger.info(`Stopped ${this.constructor.name}`);
}
Expand Down Expand Up @@ -1049,7 +1092,9 @@ class WebSocketConnection {
protected startKeepAliveIntervalTimer(ms: number): void {
const keepAliveHandler = async (signal: AbortSignal) => {
if (signal.aborted) return;
this.socket.ping();
if (utils.isNodeWebsocket(this.socket)) {
this.socket.ping();
}
Comment thread
amydevs marked this conversation as resolved.
this.keepAliveIntervalTimer = new Timer({
delay: ms,
handler: keepAliveHandler,
Expand Down
6 changes: 6 additions & 0 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { Callback, Host, Port, PromiseDeconstructed } from './types';
import type { DetailedPeerCertificate } from 'tls';
import * as dns from 'dns';
import { IPv4, IPv6, Validator } from 'ip-num';
import * as ws from 'ws';
import * as errors from './errors';

const textEncoder = new TextEncoder();
Expand All @@ -11,6 +12,10 @@ function never(message?: string): never {
throw new errors.ErrorWebSocketUndefinedBehaviour(message);
}

function isNodeWebsocket(websocket: any): websocket is ws.WebSocket {
return websocket.constructor === ws.WebSocket;
}

/**
* Is it an IPv4 address?
*/
Expand Down Expand Up @@ -467,6 +472,7 @@ export {
textEncoder,
textDecoder,
never,
isNodeWebsocket,
isIPv4,
isIPv6,
isIPv4MappedIPv6,
Expand Down
Loading