Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/extension/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"@endo/promise-kit": "^1.1.6",
"@metamask/snaps-utils": "^8.3.0",
"@metamask/utils": "^9.3.0",
"@ocap/errors": "workspace:^",
"@ocap/kernel": "workspace:^",
"@ocap/shims": "workspace:^",
"@ocap/streams": "workspace:^",
Expand Down
30 changes: 17 additions & 13 deletions packages/extension/src/VatWorkerClient.test.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import '@ocap/shims/endoify';
import type { VatId } from '@ocap/kernel';
import type { VatId, VatWorkerServiceCommandReply } from '@ocap/kernel';
import { VatWorkerServiceCommandMethod } from '@ocap/kernel';
import { delay } from '@ocap/test-utils';
import type { Logger } from '@ocap/utils';
import { makeLogger } from '@ocap/utils';
import { describe, it, expect, beforeEach, vi } from 'vitest';

import { VatWorkerServiceMethod } from './vat-worker-service.js';
import type { ExtensionVatWorkerClient } from './VatWorkerClient.js';
import { makeTestClient } from '../test/vat-worker-service.js';

Expand Down Expand Up @@ -40,16 +40,18 @@ describe('ExtensionVatWorkerClient', () => {

it.each`
method
${VatWorkerServiceMethod.Init}
${VatWorkerServiceMethod.Delete}
${VatWorkerServiceCommandMethod.Launch}
${VatWorkerServiceCommandMethod.Terminate}
`(
"calls logger.error when receiving a $method reply it wasn't waiting for",
async ({ method }) => {
const errorSpy = vi.spyOn(clientLogger, 'error');
const unexpectedReply = {
method,
id: 9,
vatId: 'v0',
const unexpectedReply: VatWorkerServiceCommandReply = {
id: 'm9',
payload: {
method,
params: { vatId: 'v0' },
},
};
serverPort.postMessage(unexpectedReply);
await delay(100);
Expand All @@ -61,15 +63,17 @@ describe('ExtensionVatWorkerClient', () => {
},
);

it(`calls logger.error when receiving a ${VatWorkerServiceMethod.Init} reply without a port`, async () => {
it(`calls logger.error when receiving a ${VatWorkerServiceCommandMethod.Launch} reply without a port`, async () => {
const errorSpy = vi.spyOn(clientLogger, 'error');
const vatId: VatId = 'v0';
// eslint-disable-next-line @typescript-eslint/no-floating-promises
client.initWorker(vatId);
client.launch(vatId);
const reply = {
method: VatWorkerServiceMethod.Init,
id: 1,
vatId: 'v0',
id: 'm1',
payload: {
method: VatWorkerServiceCommandMethod.Launch,
params: { vatId: 'v0' },
},
};
serverPort.postMessage(reply);
await delay(100);
Expand Down
73 changes: 45 additions & 28 deletions packages/extension/src/VatWorkerClient.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
import { makePromiseKit } from '@endo/promise-kit';
import type { PromiseKit } from '@endo/promise-kit';
import { isObject } from '@metamask/utils';
import { unmarshalError } from '@ocap/errors';
import {
VatWorkerServiceCommandMethod,
isVatWorkerServiceCommandReply,
} from '@ocap/kernel';
import type {
StreamEnvelope,
StreamEnvelopeReply,
VatWorkerService,
VatId,
VatWorkerServiceCommand,
} from '@ocap/kernel';
import type { DuplexStream } from '@ocap/streams';
import { MessagePortDuplexStream } from '@ocap/streams';
import type { Logger } from '@ocap/utils';
import { makeCounter, makeHandledCallback, makeLogger } from '@ocap/utils';

import type { AddListener } from './vat-worker-service.js';
import {
isVatWorkerServiceMessage,
VatWorkerServiceMethod,
} from './vat-worker-service.js';
import type { AddListener, PostMessage } from './vat-worker-service.js';
// Appears in the docs.
// eslint-disable-next-line @typescript-eslint/no-unused-vars
import type { ExtensionVatWorkerServer } from './VatWorkerServer.js';
Expand All @@ -25,16 +28,19 @@ type PromiseCallbacks<Resolve = unknown> = Omit<PromiseKit<Resolve>, 'promise'>;
export class ExtensionVatWorkerClient implements VatWorkerService {
readonly #logger: Logger;

readonly #unresolvedMessages: Map<number, PromiseCallbacks> = new Map();
readonly #unresolvedMessages: Map<
VatWorkerServiceCommand['id'],
PromiseCallbacks
> = new Map();

readonly #messageCounter = makeCounter();

readonly #postMessage: (message: unknown) => void;
readonly #postMessage: PostMessage<VatWorkerServiceCommand>;

/**
* The client end of the vat worker service, intended to be constructed in
* the kernel worker. Sends initWorker and deleteWorker requests to the
* server and wraps the initWorker response in a DuplexStream for consumption
* the kernel worker. Sends launch and terminate worker requests to the
* server and wraps the launch response in a DuplexStream for consumption
* by the kernel.
*
* @see {@link ExtensionVatWorkerServer} for the other end of the service.
Expand All @@ -44,7 +50,7 @@ export class ExtensionVatWorkerClient implements VatWorkerService {
* @param logger - An optional {@link Logger}. Defaults to a new logger labeled '[vat worker client]'.
*/
constructor(
postMessage: (message: unknown) => void,
postMessage: PostMessage<VatWorkerServiceCommand>,
addListener: AddListener,
logger?: Logger,
) {
Expand All @@ -54,15 +60,11 @@ export class ExtensionVatWorkerClient implements VatWorkerService {
}

async #sendMessage<Return>(
method:
| typeof VatWorkerServiceMethod.Init
| typeof VatWorkerServiceMethod.Delete,
vatId: VatId,
payload: VatWorkerServiceCommand['payload'],
): Promise<Return> {
const message = {
id: this.#messageCounter(),
method,
vatId,
const message: VatWorkerServiceCommand = {
id: `m${this.#messageCounter()}`,
payload,
};
const { promise, resolve, reject } = makePromiseKit<Return>();
this.#unresolvedMessages.set(message.id, {
Expand All @@ -73,24 +75,38 @@ export class ExtensionVatWorkerClient implements VatWorkerService {
return promise;
}

async initWorker(
async launch(
vatId: VatId,
): Promise<DuplexStream<StreamEnvelopeReply, StreamEnvelope>> {
return this.#sendMessage(VatWorkerServiceMethod.Init, vatId);
return this.#sendMessage({
method: VatWorkerServiceCommandMethod.Launch,
params: { vatId },
});
}

async deleteWorker(vatId: VatId): Promise<undefined> {
return this.#sendMessage(VatWorkerServiceMethod.Delete, vatId);
async terminate(vatId: VatId): Promise<undefined> {
return this.#sendMessage({
method: VatWorkerServiceCommandMethod.Terminate,
params: { vatId },
});
}

async terminateAll(): Promise<void> {
return this.#sendMessage({
method: VatWorkerServiceCommandMethod.TerminateAll,
params: null,
});
}

async #handleMessage(event: MessageEvent<unknown>): Promise<void> {
if (!isVatWorkerServiceMessage(event.data)) {
if (!isVatWorkerServiceCommandReply(event.data)) {
// This happens when other messages pass through the same channel.
this.#logger.debug('Received unexpected message', event.data);
return;
}

const { id, method, error } = event.data;
const { id, payload } = event.data;
const { method } = payload;
const port = event.ports.at(0);

const promise = this.#unresolvedMessages.get(id);
Expand All @@ -100,13 +116,13 @@ export class ExtensionVatWorkerClient implements VatWorkerService {
return;
}

if (error) {
promise.reject(error);
if (isObject(payload.params) && payload.params.error) {
promise.reject(unmarshalError(payload.params.error));
return;
}

switch (method) {
case VatWorkerServiceMethod.Init:
case VatWorkerServiceCommandMethod.Launch:
if (!port) {
this.#logger.error('Expected a port with message reply', event);
return;
Expand All @@ -117,7 +133,8 @@ export class ExtensionVatWorkerClient implements VatWorkerService {
),
);
break;
case VatWorkerServiceMethod.Delete:
case VatWorkerServiceCommandMethod.Terminate:
case VatWorkerServiceCommandMethod.TerminateAll:
// If we were caching streams on the client this would be a good place
// to remove them.
promise.resolve(undefined);
Expand Down
97 changes: 70 additions & 27 deletions packages/extension/src/VatWorkerServer.test.ts
Original file line number Diff line number Diff line change
@@ -1,57 +1,100 @@
import '@ocap/shims/endoify';
import type { NonEmptyArray } from '@metamask/utils';
import { VatNotFoundError } from '@ocap/errors';
import { VatWorkerServiceCommandMethod } from '@ocap/kernel';
import { delay } from '@ocap/test-utils';
import type { Logger } from '@ocap/utils';
import { makeLogger } from '@ocap/utils';
import { describe, it, expect, beforeEach, vi } from 'vitest';

import type { VatWorker } from './vat-worker-service.js';
import type { ExtensionVatWorkerServer } from './VatWorkerServer.js';
import { makeTestServer } from '../test/vat-worker-service.js';

describe('VatWorker', () => {
describe('ExtensionVatWorkerServer', () => {
let serverPort: MessagePort;
let clientPort: MessagePort;

let logger: Logger;

let server: ExtensionVatWorkerServer;

// let vatPort: MessagePort;
let kernelPort: MessagePort;

beforeEach(() => {
const serviceMessageChannel = new MessageChannel();
serverPort = serviceMessageChannel.port1;
clientPort = serviceMessageChannel.port2;

logger = makeLogger('[test server]');
});

const deliveredMessageChannel = new MessageChannel();
// vatPort = deliveredMessageChannel.port1;
kernelPort = deliveredMessageChannel.port2;
describe('Misc', () => {
beforeEach(() => {
[server] = makeTestServer({ serverPort, logger });
});

server = makeTestServer({ serverPort, logger, kernelPort });
});
it('starts', () => {
server.start();
expect(serverPort.onmessage).toBeDefined();
});

it('starts', () => {
server.start();
expect(serverPort.onmessage).toBeDefined();
});
it('throws if started twice', () => {
server.start();
expect(() => server.start()).toThrow(/already running/u);
});

it('throws if started twice', () => {
server.start();
expect(() => server.start()).toThrow(/already running/u);
it('calls logger.debug when receiving an unexpected message', async () => {
const debugSpy = vi.spyOn(logger, 'debug');
const unexpectedMessage = 'foobar';
server.start();
clientPort.postMessage(unexpectedMessage);
await delay(100);
expect(debugSpy).toHaveBeenCalledOnce();
expect(debugSpy).toHaveBeenLastCalledWith(
'Received unexpected message',
unexpectedMessage,
);
});
});

it('calls logger.debug when receiving an unexpected message', async () => {
const debugSpy = vi.spyOn(logger, 'debug');
const unexpectedMessage = 'foobar';
server.start();
clientPort.postMessage(unexpectedMessage);
await delay(100);
expect(debugSpy).toHaveBeenCalledOnce();
expect(debugSpy).toHaveBeenLastCalledWith(
'Received unexpected message',
unexpectedMessage,
);
describe('terminateAll', () => {
let workers: NonEmptyArray<VatWorker>;

beforeEach(() => {
[server, ...workers] = makeTestServer({
serverPort,
logger,
nWorkers: 3,
});
});

it('calls logger.error when a vat fails to terminate', async () => {
const errorSpy = vi.spyOn(logger, 'error');
const vatId = 'v0';
const vatNotFoundError = new VatNotFoundError(vatId);
vi.spyOn(workers[0], 'terminate').mockRejectedValue(vatNotFoundError);
server.start();
clientPort.postMessage({
id: 'm0',
payload: {
method: VatWorkerServiceCommandMethod.Launch,
params: { vatId },
},
});
clientPort.postMessage({
id: 'm1',
payload: {
method: VatWorkerServiceCommandMethod.TerminateAll,
params: null,
},
});

await delay(100);

expect(errorSpy).toHaveBeenCalledOnce();
expect(errorSpy.mock.lastCall?.[0]).toBe(
`Error handling ${VatWorkerServiceCommandMethod.TerminateAll} for vatId ${vatId}`,
);
expect(errorSpy.mock.lastCall?.[1]).toBe(vatNotFoundError);
});
});
});
Loading