Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions src/RPCClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import * as utils from './utils';
const timerCleanupReasonSymbol = Symbol('timerCleanUpReasonSymbol');

class RPCClient<M extends ClientManifest> {
protected onTimeoutCallback?: () => void;
protected idGen: IdGen;
protected logger: Logger;
protected streamFactory: StreamFactory;
Expand All @@ -37,28 +36,30 @@ class RPCClient<M extends ClientManifest> {
Uint8Array
>;
protected callerTypes: Record<string, HandlerType>;
public registerOnTimeoutCallback(callback: () => void) {
this.onTimeoutCallback = callback;
}
// Method proxies
public readonly timeoutTime: number;
public readonly graceTime: number;
public readonly methodsProxy = new Proxy(
{},
{
get: (_, method) => {
if (typeof method === 'symbol') return;
switch (this.callerTypes[method]) {
case 'UNARY':
return (params, ctx) => this.unaryCaller(method, params, ctx);
return (params: JSONObject, ctx: Partial<ContextTimedInput>) =>
this.unaryCaller(method, params, ctx);
case 'SERVER':
return (params, ctx) =>
return (params: JSONObject, ctx: Partial<ContextTimedInput>) =>
this.serverStreamCaller(method, params, ctx);
case 'CLIENT':
return (ctx) => this.clientStreamCaller(method, ctx);
return (ctx: Partial<ContextTimedInput>) =>
this.clientStreamCaller(method, ctx);
case 'DUPLEX':
return (ctx) => this.duplexStreamCaller(method, ctx);
return (ctx: Partial<ContextTimedInput>) =>
this.duplexStreamCaller(method, ctx);
case 'RAW':
return (header, ctx) => this.rawStreamCaller(method, header, ctx);
return (header: JSONObject, ctx: Partial<ContextTimedInput>) =>
this.rawStreamCaller(method, header, ctx);
default:
return;
}
Expand Down Expand Up @@ -86,6 +87,7 @@ class RPCClient<M extends ClientManifest> {
streamFactory,
middlewareFactory = middleware.defaultClientMiddlewareWrapper(),
timeoutTime = Infinity,
graceTime = 1000,
logger,
toError = utils.toError,
idGen = () => null,
Expand All @@ -99,6 +101,7 @@ class RPCClient<M extends ClientManifest> {
Uint8Array
>;
timeoutTime?: number;
graceTime?: number;
logger?: Logger;
idGen?: IdGen;
toError?: ToError;
Expand All @@ -111,6 +114,7 @@ class RPCClient<M extends ClientManifest> {
this.streamFactory = streamFactory;
this.middlewareFactory = middlewareFactory;
this.timeoutTime = timeoutTime;
this.graceTime = graceTime;
this.logger = logger ?? new Logger(this.constructor.name);
this.toError = toError;
}
Expand Down Expand Up @@ -262,7 +266,9 @@ class RPCClient<M extends ClientManifest> {
} else {
timer = ctx.timer;
}
let timerGrace: Timer | undefined;
const cleanUp = () => {
if (timerGrace != null) timerGrace.cancel(timerCleanupReasonSymbol);
// Clean up the timer and signal
if (ctx.timer == null) timer.cancel(timerCleanupReasonSymbol);
if (ctx.signal != null) {
Expand All @@ -278,9 +284,6 @@ class RPCClient<M extends ClientManifest> {
void timer.then(
() => {
abortController.abort(timeoutError);
if (this.onTimeoutCallback) {
this.onTimeoutCallback();
}
},
() => {}, // Ignore cancellation error
);
Expand All @@ -298,7 +301,14 @@ class RPCClient<M extends ClientManifest> {
throw e;
}
void timer.then(
() => {
async () => {
timerGrace = new Timer({ delay: this.graceTime });
try {
await timerGrace;
} catch (e) {
if (e === timerCleanupReasonSymbol) return;
throw e;
}
rpcStream.cancel(
new errors.ErrorRPCTimedOut('RPC has timed out', {
cause: ctx.signal?.reason,
Expand Down Expand Up @@ -368,7 +378,6 @@ class RPCClient<M extends ClientManifest> {
* single RPC message that is sent to specify the method for the RPC call.
* Any metadata of extra parameters is provided here.
* @param ctx - ContextTimed used for timeouts and cancellation.
* @param id - Id is generated only once, and used throughout the stream for the rest of the communication
*/
public async rawStreamCaller(
method: string,
Expand Down
9 changes: 1 addition & 8 deletions src/RPCServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ interface RPCServer extends startStop.StartStop {}
eventStopped: events.EventRPCServerStopped,
})
class RPCServer {
protected onTimeoutCallback?: () => void;
protected idGen: IdGen;
protected logger: Logger;
protected handlerMap: Map<string, RawHandlerImplementation> = new Map();
Expand All @@ -68,14 +67,11 @@ class RPCServer {
Uint8Array,
JSONRPCResponseSuccess
>;
// Function to register a callback for timeout
public registerOnTimeoutCallback(callback: () => void) {
this.onTimeoutCallback = callback;
}

/**
* RPCServer Constructor
*
* @param obj
* @param obj.middlewareFactory - Middleware used to process the rpc messages.
* The middlewareFactory needs to be a function that creates a pair of
* transform streams that convert `Uint8Array` to `JSONRPCRequest` on the forward
Expand Down Expand Up @@ -464,9 +460,6 @@ class RPCServer {
delay: this.timeoutTime,
handler: () => {
abortController.abort(new errors.ErrorRPCTimedOut());
if (this.onTimeoutCallback) {
this.onTimeoutCallback();
}
},
});

Expand Down
6 changes: 3 additions & 3 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ type JSONRPCRequestMetadata = Partial<{
}>;

/**
* `T` is the the params you want to specify.
* `T` is the params you want to specify.
*
* `M` is the metadata you want to specify.
*
Expand All @@ -141,7 +141,7 @@ type JSONRPCResponseMetadata = Partial<{
}>;

/**
* `T` is the the result you want to specify.
* `T` is the result you want to specify.
*
* `M` is the metadata you want to specify.
*
Expand Down Expand Up @@ -251,7 +251,7 @@ interface RPCStream<R, W, M extends POJO = POJO>
}

/**
* This is a factory for creating a `RPCStream` when making a RPC call.
* This is a factory for creating a `RPCStream` when making an RPC call.
* The transport mechanism is a black box to the RPC system. So long as it is
* provided as a RPCStream the RPC system should function. It is assumed that
* the RPCStream communicates with an `RPCServer`.
Expand Down
9 changes: 4 additions & 5 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ import type {
ClientManifest,
HandlerType,
JSONObject,
JSONRPCResponseError,
JSONRPCMessage,
JSONRPCRequest,
JSONRPCRequestMessage,
JSONRPCRequestNotification,
JSONRPCResponse,
JSONRPCResponseError,
JSONRPCResponseFailed,
JSONRPCResponseSuccess,
JSONValue,
Expand All @@ -22,7 +22,7 @@ import * as errors from './errors';

const timeoutCancelledReason = Symbol('timeoutCancelledReason');

// Importing PK funcs and utils which are essential for RPC
// Importing PK functions and utils which are essential for RPC
function isObject(o: unknown): o is object {
return o !== null && typeof o === 'object';
}
Expand Down Expand Up @@ -222,7 +222,7 @@ function parseJSONRPCMessage<T extends JSONObject>(
* @throws {TypeError} If the error is an instance of {@link Symbol}, {@link BigInt} or {@link Function}.
*/
function fromError(error: any): JSONValue {
// TODO: Linked-List traversal must be done iteractively rather than recusively to prevent stack overflow.
// TODO: Linked-List traversal must be done interactively rather than recursively to prevent stack overflow.
switch (typeof error) {
case 'symbol':
case 'bigint':
Expand Down Expand Up @@ -385,10 +385,9 @@ function toError(
e.cause = toError(errorData.data.cause, clientMetadata, false);
}
if (top) {
const err = new errors.ErrorRPCRemote(clientMetadata, undefined, {
return new errors.ErrorRPCRemote(clientMetadata, undefined, {
cause: e,
});
return err;
} else {
return e;
}
Expand Down
Loading