diff --git a/packages/extension/test/e2e/control-panel.test.ts b/packages/extension/test/e2e/control-panel.test.ts index 5614e9f70..c3d1e22fb 100644 --- a/packages/extension/test/e2e/control-panel.test.ts +++ b/packages/extension/test/e2e/control-panel.test.ts @@ -116,14 +116,13 @@ test.describe('Control Panel', () => { await expect(popupPage.locator('table tr')).toHaveCount(3); }); - // TODO: Fix this test once the ping method is implemented - test.skip('should ping a vat', async () => { + test('should ping a vat', async () => { await expect( popupPage.locator('td button:text("Ping")').first(), ).toBeVisible(); await popupPage.locator('td button:text("Ping")').first().click(); - await expect(messageOutput).toContainText('"method": "ping",'); - await expect(messageOutput).toContainText('{"result":"pong"}'); + await expect(messageOutput).toContainText('"method": "pingVat",'); + await expect(messageOutput).toContainText('pong'); }); test('should terminate all vats', async () => { diff --git a/packages/kernel-store/src/sqlite/common.test.ts b/packages/kernel-store/src/sqlite/common.test.ts index 5fa1a4fdb..852fdfbe9 100644 --- a/packages/kernel-store/src/sqlite/common.test.ts +++ b/packages/kernel-store/src/sqlite/common.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect } from 'vitest'; -import { SQL_QUERIES } from './common.ts'; +import { SQL_QUERIES, assertSafeIdentifier } from './common.ts'; describe('SQL_QUERIES', () => { // XXX Is this test actually useful? It's basically testing that the source code matches itself. @@ -40,10 +40,12 @@ describe('SQL_QUERIES', () => { it('has all expected query properties', () => { expect(Object.keys(SQL_QUERIES).sort()).toStrictEqual([ 'ABORT_TRANSACTION', + 'BEGIN_IMMEDIATE_TRANSACTION', 'BEGIN_TRANSACTION', 'CLEAR', 'CLEAR_VS', 'COMMIT_TRANSACTION', + 'CREATE_SAVEPOINT', 'CREATE_TABLE', 'CREATE_TABLE_VS', 'DELETE', @@ -54,8 +56,46 @@ describe('SQL_QUERIES', () => { 'GET', 'GET_ALL_VS', 'GET_NEXT', + 'RELEASE_SAVEPOINT', + 'ROLLBACK_SAVEPOINT', 'SET', 'SET_VS', ]); }); }); + +describe('assertSafeIdentifier', () => { + it('accepts valid SQL identifiers', () => { + expect(() => assertSafeIdentifier('valid')).not.toThrow(); + expect(() => assertSafeIdentifier('Valid')).not.toThrow(); + expect(() => assertSafeIdentifier('valid_name')).not.toThrow(); + expect(() => assertSafeIdentifier('valid_name_123')).not.toThrow(); + expect(() => assertSafeIdentifier('_leading_underscore')).not.toThrow(); + }); + + it('rejects invalid SQL identifiers', () => { + // Starting with a number + expect(() => assertSafeIdentifier('123invalid')).toThrow( + 'Invalid identifier', + ); + + // Containing invalid characters + expect(() => assertSafeIdentifier('invalid-name')).toThrow( + 'Invalid identifier', + ); + expect(() => assertSafeIdentifier('invalid.name')).toThrow( + 'Invalid identifier', + ); + expect(() => assertSafeIdentifier('invalid;name')).toThrow( + 'Invalid identifier', + ); + expect(() => assertSafeIdentifier('invalid name')).toThrow( + 'Invalid identifier', + ); + + // Containing SQL injection attempts + expect(() => assertSafeIdentifier("name'; DROP TABLE users--")).toThrow( + 'Invalid identifier', + ); + }); +}); diff --git a/packages/kernel-store/src/sqlite/common.ts b/packages/kernel-store/src/sqlite/common.ts index 860c7d042..672b07d9d 100644 --- a/packages/kernel-store/src/sqlite/common.ts +++ b/packages/kernel-store/src/sqlite/common.ts @@ -52,30 +52,35 @@ export const SQL_QUERIES = { DELETE FROM kv_vatstore WHERE vatID = ? `, - CLEAR: ` - DELETE FROM kv - `, - CLEAR_VS: ` - DELETE FROM kv_vatstore - `, - DROP: ` - DROP TABLE kv - `, - DROP_VS: ` - DROP TABLE kv_vatstore - `, - BEGIN_TRANSACTION: ` - BEGIN TRANSACTION - `, - COMMIT_TRANSACTION: ` - COMMIT TRANSACTION - `, - ABORT_TRANSACTION: ` - ROLLBACK TRANSACTION - `, + CLEAR: `DELETE FROM kv`, + CLEAR_VS: `DELETE FROM kv_vatstore`, + DROP: `DROP TABLE kv`, + DROP_VS: `DROP TABLE kv_vatstore`, + BEGIN_TRANSACTION: `BEGIN TRANSACTION`, + BEGIN_IMMEDIATE_TRANSACTION: `BEGIN IMMEDIATE TRANSACTION`, + COMMIT_TRANSACTION: `COMMIT TRANSACTION`, + ABORT_TRANSACTION: `ROLLBACK TRANSACTION`, + // SQLite's parameter markers (?, ?NNN, :name, @name, $name) can only be used + // in places where a literal value is allowed. We can't bind identifiers + // for table names, column names, or savepoint names. We use %NAME% as a + // placeholder for the savepoint name. + CREATE_SAVEPOINT: `SAVEPOINT %NAME%`, + ROLLBACK_SAVEPOINT: `ROLLBACK TO SAVEPOINT %NAME%`, + RELEASE_SAVEPOINT: `RELEASE SAVEPOINT %NAME%`, } as const; /** * The default filename for the SQLite database; ":memory:" is an ephemeral in-memory database. */ export const DEFAULT_DB_FILENAME = ':memory:'; + +/** + * Check if a string is a valid SQLite identifier. + * + * @param name - The string to check. + */ +export function assertSafeIdentifier(name: string): void { + if (!/^[A-Za-z_]\w*$/u.test(name)) { + throw new Error(`Invalid identifier: ${name}`); + } +} diff --git a/packages/kernel-store/src/sqlite/nodejs.test.ts b/packages/kernel-store/src/sqlite/nodejs.test.ts index a3727e423..afe27845a 100644 --- a/packages/kernel-store/src/sqlite/nodejs.test.ts +++ b/packages/kernel-store/src/sqlite/nodejs.test.ts @@ -25,6 +25,10 @@ const mockStatement = { const mockDb = { prepare: vi.fn(() => mockStatement), transaction: vi.fn((fn) => fn), + exec: vi.fn(), + inTransaction: false, + // eslint-disable-next-line @typescript-eslint/naming-convention + _spStack: [] as string[], }; vi.mock('better-sqlite3', () => ({ @@ -137,6 +141,49 @@ describe('makeSQLKernelDatabase', () => { expect(mockStatement.run).toHaveBeenCalled(); // commit transaction }); + describe('deleteVatStore functionality', () => { + beforeEach(() => { + Object.values(mockStatement).forEach((mock) => mock.mockReset()); + }); + + it('deleteVatStore removes all data for a given vat', async () => { + const db = await makeSQLKernelDatabase({}); + const vatId = 'test-vat'; + db.deleteVatStore(vatId); + expect(mockDb.prepare).toHaveBeenCalledWith(SQL_QUERIES.DELETE_VS_ALL); + expect(mockStatement.run).toHaveBeenCalledWith(vatId); + }); + + it('deleteVatStore handles empty vatId correctly', async () => { + const db = await makeSQLKernelDatabase({}); + db.deleteVatStore(''); + expect(mockStatement.run).toHaveBeenCalledWith(''); + }); + + it("deleteVatStore doesn't affect other vat stores", async () => { + const db = await makeSQLKernelDatabase({}); + db.makeVatStore('vat1'); + const vatStore2 = db.makeVatStore('vat2'); + db.deleteVatStore('vat1'); + mockStatement.iterate.mockReturnValueOnce([ + { key: 'testKey', value: 'testValue' }, + ]); + const data = vatStore2.getKVData(); + expect(data).toStrictEqual([['testKey', 'testValue']]); + expect(mockStatement.iterate).toHaveBeenCalledWith('vat2'); + }); + + it('deleteVatStore handles errors correctly', async () => { + const db = await makeSQLKernelDatabase({}); + mockStatement.run.mockImplementationOnce(() => { + throw new Error('Database error during delete'); + }); + expect(() => db.deleteVatStore('test-vat')).toThrow( + 'Database error during delete', + ); + }); + }); + describe('getDBFilename', () => { it('returns in-memory database path when label starts with ":"', async () => { const result = await getDBFilename(':memory:'); @@ -151,4 +198,124 @@ describe('makeSQLKernelDatabase', () => { }); }); }); + + describe('savepoint functionality', () => { + beforeEach(() => { + mockDb.exec.mockClear(); + mockDb.inTransaction = false; + mockDb._spStack = []; + }); + + it('creates a savepoint using sanitized name', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('valid_name'); + + expect(mockDb.exec).toHaveBeenCalledWith('SAVEPOINT valid_name'); + }); + + it('rejects invalid savepoint names', async () => { + const db = await makeSQLKernelDatabase({}); + expect(() => db.createSavepoint('invalid-name')).toThrow( + 'Invalid identifier', + ); + expect(() => db.createSavepoint('123numeric')).toThrow( + 'Invalid identifier', + ); + expect(() => db.createSavepoint('spaces not allowed')).toThrow( + 'Invalid identifier', + ); + expect(() => db.createSavepoint("point'; DROP TABLE kv--")).toThrow( + 'Invalid identifier', + ); + expect(mockDb.exec).not.toHaveBeenCalledWith( + expect.stringContaining('DROP TABLE'), + ); + }); + + it('rolls back to a savepoint', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('test_point'); + db.rollbackSavepoint('test_point'); + expect(mockDb.exec).toHaveBeenCalledWith( + 'ROLLBACK TO SAVEPOINT test_point', + ); + }); + + it('releases a savepoint', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('test_point'); + db.releaseSavepoint('test_point'); + expect(mockDb.exec).toHaveBeenCalledWith('RELEASE SAVEPOINT test_point'); + }); + + it('createSavepoint begins transaction if needed', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('test_point'); + expect(mockDb._spStack).toContain('test_point'); + expect(mockDb.exec).toHaveBeenCalledWith('SAVEPOINT test_point'); + }); + + it('rollbackSavepoint validates savepoint exists', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb.inTransaction = true; + mockDb._spStack = ['existing_point']; + expect(() => db.rollbackSavepoint('nonexistent_point')).toThrow( + 'No such savepoint: nonexistent_point', + ); + }); + + it('rollbackSavepoint removes all points after target', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb.inTransaction = true; + mockDb._spStack = ['point1', 'point2', 'point3']; + db.rollbackSavepoint('point2'); + expect(mockDb._spStack).toStrictEqual(['point1']); + expect(mockDb.exec).toHaveBeenCalledWith('ROLLBACK TO SAVEPOINT point2'); + }); + + it('rollbackSavepoint closes transaction if no savepoints remain', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb.inTransaction = true; + mockDb._spStack = ['point1']; + db.rollbackSavepoint('point1'); + expect(mockDb._spStack).toStrictEqual([]); + }); + + it('releaseSavepoint validates savepoint exists', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb.inTransaction = true; + mockDb._spStack = ['existing_point']; + expect(() => db.releaseSavepoint('nonexistent_point')).toThrow( + 'No such savepoint: nonexistent_point', + ); + }); + + it('releaseSavepoint removes all points after target', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb.inTransaction = true; + mockDb._spStack = ['point1', 'point2', 'point3']; + db.releaseSavepoint('point2'); + expect(mockDb._spStack).toStrictEqual(['point1']); + expect(mockDb.exec).toHaveBeenCalledWith('RELEASE SAVEPOINT point2'); + }); + + it('releaseSavepoint commits transaction if no savepoints remain', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb.inTransaction = true; + mockDb._spStack = ['point1']; + db.releaseSavepoint('point1'); + expect(mockDb._spStack).toStrictEqual([]); + }); + + it('supports nested savepoints', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('outer'); + db.createSavepoint('inner'); + expect(mockDb._spStack).toStrictEqual(['outer', 'inner']); + db.rollbackSavepoint('inner'); + expect(mockDb._spStack).toStrictEqual(['outer']); + db.releaseSavepoint('outer'); + expect(mockDb._spStack).toStrictEqual([]); + }); + }); }); diff --git a/packages/kernel-store/src/sqlite/nodejs.ts b/packages/kernel-store/src/sqlite/nodejs.ts index 456173913..3d9eb1051 100644 --- a/packages/kernel-store/src/sqlite/nodejs.ts +++ b/packages/kernel-store/src/sqlite/nodejs.ts @@ -1,15 +1,24 @@ import { Logger } from '@metamask/logger'; -import type { Database } from 'better-sqlite3'; +import type { Database as SqliteDatabase } from 'better-sqlite3'; // eslint-disable-next-line @typescript-eslint/naming-convention import Sqlite from 'better-sqlite3'; import { mkdir } from 'fs/promises'; import { tmpdir } from 'os'; import { join } from 'path'; -import { SQL_QUERIES, DEFAULT_DB_FILENAME } from './common.ts'; +import { + SQL_QUERIES, + DEFAULT_DB_FILENAME, + assertSafeIdentifier, +} from './common.ts'; import { getDBFolder } from './env.ts'; import type { KVStore, VatStore, KernelDatabase } from '../types.ts'; +export type Database = SqliteDatabase & { + // stack of active savepoint names + _spStack: string[]; +}; + /** * Ensure that SQLite is initialized. * @@ -25,11 +34,13 @@ async function initDB( ): Promise { const dbPath = await getDBFilename(dbFilename); logger.debug('dbPath:', dbPath); - return new Sqlite(dbPath, { + const db = new Sqlite(dbPath, { verbose: (verbose ? logger.info.bind(logger) : undefined) as | ((...args: unknown[]) => void) | undefined, - }); + }) as Database; + db._spStack = []; + return db; } /** @@ -140,6 +151,45 @@ export async function makeSQLKernelDatabase({ const sqlKVClear = db.prepare(SQL_QUERIES.CLEAR); const sqlKVClearVS = db.prepare(SQL_QUERIES.CLEAR_VS); + const sqlVatstoreGetAll = db.prepare(SQL_QUERIES.GET_ALL_VS); + const sqlVatstoreSet = db.prepare(SQL_QUERIES.SET_VS); + const sqlVatstoreDelete = db.prepare(SQL_QUERIES.DELETE_VS); + const sqlVatstoreDeleteAll = db.prepare(SQL_QUERIES.DELETE_VS_ALL); + const sqlBeginTransaction = db.prepare(SQL_QUERIES.BEGIN_TRANSACTION); + const sqlCommitTransaction = db.prepare(SQL_QUERIES.COMMIT_TRANSACTION); + const sqlAbortTransaction = db.prepare(SQL_QUERIES.ABORT_TRANSACTION); + + /** + * Begin a transaction if not already in one + * + * @returns True if a new transaction was started, false if already in one + */ + function beginIfNeeded(): boolean { + if (db.inTransaction) { + return false; + } + sqlBeginTransaction.run(); + return true; + } + + /** + * Commit a transaction if one is active and no savepoints remain + */ + function commitIfNeeded(): void { + if (db.inTransaction && db._spStack.length === 0) { + sqlCommitTransaction.run(); + } + } + + /** + * Rollback a transaction + */ + function rollbackIfNeeded(): void { + if (db.inTransaction) { + sqlAbortTransaction.run(); + db._spStack.length = 0; + } + } /** * Delete everything from the database. @@ -160,11 +210,6 @@ export async function makeSQLKernelDatabase({ return query.all() as Record[]; } - const sqlVatstoreGetAll = db.prepare(SQL_QUERIES.GET_ALL_VS); - const sqlVatstoreSet = db.prepare(SQL_QUERIES.SET_VS); - const sqlVatstoreDelete = db.prepare(SQL_QUERIES.DELETE_VS); - const sqlVatstoreDeleteAll = db.prepare(SQL_QUERIES.DELETE_VS_ALL); - /** * Create a new VatStore for a vat. * @@ -223,12 +268,69 @@ export async function makeSQLKernelDatabase({ sqlVatstoreDeleteAll.run(vatId); } + /** + * Create a savepoint in the database. + * + * @param name - The name of the savepoint. + */ + function createSavepoint(name: string): void { + // We must be in a transaction when creating the savepoint or releasing it + // later will cause an autocommit. + // See https://github.com/Agoric/agoric-sdk/issues/8423 + beginIfNeeded(); + assertSafeIdentifier(name); + const query = SQL_QUERIES.CREATE_SAVEPOINT.replace('%NAME%', name); + db.exec(query); + db._spStack.push(name); + } + + /** + * Rollback to a savepoint in the database. + * + * @param name - The name of the savepoint. + */ + function rollbackSavepoint(name: string): void { + assertSafeIdentifier(name); + const idx = db._spStack.lastIndexOf(name); + if (idx < 0) { + throw new Error(`No such savepoint: ${name}`); + } + const query = SQL_QUERIES.ROLLBACK_SAVEPOINT.replace('%NAME%', name); + db.exec(query); + db._spStack.splice(idx); + if (db._spStack.length === 0) { + rollbackIfNeeded(); + } + } + + /** + * Release a savepoint in the database. + * + * @param name - The name of the savepoint. + */ + function releaseSavepoint(name: string): void { + assertSafeIdentifier(name); + const idx = db._spStack.lastIndexOf(name); + if (idx < 0) { + throw new Error(`No such savepoint: ${name}`); + } + const query = SQL_QUERIES.RELEASE_SAVEPOINT.replace('%NAME%', name); + db.exec(query); + db._spStack.splice(idx); + if (db._spStack.length === 0) { + commitIfNeeded(); + } + } + return { kernelKVStore: kvStore, executeQuery: kvExecuteQuery, clear: db.transaction(kvClear), makeVatStore, deleteVatStore, + createSavepoint, + rollbackSavepoint, + releaseSavepoint, }; } diff --git a/packages/kernel-store/src/sqlite/wasm.test.ts b/packages/kernel-store/src/sqlite/wasm.test.ts index 69a828d12..64f167c46 100644 --- a/packages/kernel-store/src/sqlite/wasm.test.ts +++ b/packages/kernel-store/src/sqlite/wasm.test.ts @@ -28,6 +28,10 @@ const mockStatement = { const mockDb = { exec: vi.fn(), prepare: vi.fn(() => mockStatement), + // eslint-disable-next-line @typescript-eslint/naming-convention + _inTx: false, + // eslint-disable-next-line @typescript-eslint/naming-convention + _spStack: [] as string[], }; const OpfsDbMock = vi.fn(() => mockDb); const DBMock = vi.fn(() => mockDb); @@ -364,4 +368,235 @@ describe('makeSQLKernelDatabase', () => { expect(mockStatement.reset).toHaveBeenCalled(); }); }); + + describe('savepoint functionality', () => { + beforeEach(() => { + mockDb.exec.mockClear(); + mockDb._inTx = false; + mockDb._spStack = []; + }); + + it('creates a savepoint using sanitized name', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('valid_name'); + + expect(mockDb.exec).toHaveBeenCalledWith('SAVEPOINT valid_name'); + }); + + it('rejects invalid savepoint names', async () => { + const db = await makeSQLKernelDatabase({}); + expect(() => db.createSavepoint('invalid-name')).toThrow( + 'Invalid identifier', + ); + expect(() => db.createSavepoint('123numeric')).toThrow( + 'Invalid identifier', + ); + expect(() => db.createSavepoint('spaces not allowed')).toThrow( + 'Invalid identifier', + ); + expect(() => db.createSavepoint("point'; DROP TABLE kv--")).toThrow( + 'Invalid identifier', + ); + expect(mockDb.exec).not.toHaveBeenCalledWith( + expect.stringContaining('DROP TABLE'), + ); + }); + + it('rolls back to a savepoint', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('test_point'); + db.rollbackSavepoint('test_point'); + expect(mockDb.exec).toHaveBeenCalledWith( + 'ROLLBACK TO SAVEPOINT test_point', + ); + }); + + it('releases a savepoint', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('test_point'); + db.releaseSavepoint('test_point'); + expect(mockDb.exec).toHaveBeenCalledWith('RELEASE SAVEPOINT test_point'); + }); + + it('createSavepoint begins transaction if needed', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('test_point'); + expect(mockDb._inTx).toBe(true); + expect(mockDb._spStack).toContain('test_point'); + expect(mockDb.exec).toHaveBeenCalledWith('SAVEPOINT test_point'); + }); + + it('rollbackSavepoint validates savepoint exists', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb._inTx = true; + mockDb._spStack = ['existing_point']; + expect(() => db.rollbackSavepoint('nonexistent_point')).toThrow( + 'No such savepoint: nonexistent_point', + ); + }); + + it('rollbackSavepoint removes all points after target', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb._inTx = true; + mockDb._spStack = ['point1', 'point2', 'point3']; + db.rollbackSavepoint('point2'); + expect(mockDb._spStack).toStrictEqual(['point1']); + expect(mockDb.exec).toHaveBeenCalledWith('ROLLBACK TO SAVEPOINT point2'); + }); + + it('rollbackSavepoint closes transaction if no savepoints remain', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb._inTx = true; + mockDb._spStack = ['point1']; + db.rollbackSavepoint('point1'); + expect(mockDb._spStack).toStrictEqual([]); + expect(mockDb._inTx).toBe(false); + }); + + it('releaseSavepoint validates savepoint exists', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb._inTx = true; + mockDb._spStack = ['existing_point']; + expect(() => db.releaseSavepoint('nonexistent_point')).toThrow( + 'No such savepoint: nonexistent_point', + ); + }); + + it('releaseSavepoint removes all points after target', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb._inTx = true; + mockDb._spStack = ['point1', 'point2', 'point3']; + db.releaseSavepoint('point2'); + expect(mockDb._spStack).toStrictEqual(['point1']); + expect(mockDb.exec).toHaveBeenCalledWith('RELEASE SAVEPOINT point2'); + }); + + it('releaseSavepoint commits transaction if no savepoints remain', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb._inTx = true; + mockDb._spStack = ['point1']; + db.releaseSavepoint('point1'); + expect(mockDb._spStack).toStrictEqual([]); + expect(mockDb._inTx).toBe(false); + }); + + it('supports nested savepoints', async () => { + const db = await makeSQLKernelDatabase({}); + db.createSavepoint('outer'); + db.createSavepoint('inner'); + expect(mockDb._spStack).toStrictEqual(['outer', 'inner']); + db.rollbackSavepoint('inner'); + expect(mockDb._spStack).toStrictEqual(['outer']); + expect(mockDb._inTx).toBe(true); + db.releaseSavepoint('outer'); + expect(mockDb._spStack).toStrictEqual([]); + expect(mockDb._inTx).toBe(false); + }); + }); + + it('deleteVatStore removes all data for a given vat', async () => { + Object.values(mockStatement).forEach((mock) => { + if (typeof mock === 'function' && mock.mockReset) { + mock.mockReset(); + } + }); + const db = await makeSQLKernelDatabase({}); + const vatId = 'test-vat'; + db.deleteVatStore(vatId); + expect(mockDb.prepare).toHaveBeenCalledWith(SQL_QUERIES.DELETE_VS_ALL); + expect(mockStatement.bind).toHaveBeenCalledWith([vatId]); + expect(mockStatement.step).toHaveBeenCalled(); + expect(mockStatement.reset).toHaveBeenCalled(); + }); + + it('deleteVatStore handles errors correctly', async () => { + Object.values(mockStatement).forEach((mock) => { + if (typeof mock === 'function' && mock.mockReset) { + mock.mockReset(); + } + }); + mockStatement.step.mockImplementationOnce(() => { + throw new Error('Database error'); + }); + const db = await makeSQLKernelDatabase({}); + expect(() => db.deleteVatStore('test-vat')).toThrow('Database error'); + expect(mockStatement.bind).toHaveBeenCalled(); + expect(mockStatement.reset).not.toHaveBeenCalled(); + }); + + it('deleteVatStore handles empty vatId correctly', async () => { + Object.values(mockStatement).forEach((mock) => { + if (typeof mock === 'function' && mock.mockReset) { + mock.mockReset(); + } + }); + + const db = await makeSQLKernelDatabase({}); + db.deleteVatStore(''); + expect(mockStatement.bind).toHaveBeenCalledWith(['']); + expect(mockStatement.step).toHaveBeenCalled(); + expect(mockStatement.reset).toHaveBeenCalled(); + }); + + it("deleteVatStore doesn't affect other vat stores", async () => { + Object.values(mockStatement).forEach((mock) => { + if (typeof mock === 'function' && mock.mockReset) { + mock.mockReset(); + } + }); + + const db = await makeSQLKernelDatabase({}); + db.makeVatStore('vat1'); + const vatStore2 = db.makeVatStore('vat2'); + db.deleteVatStore('vat1'); + mockStatement.step.mockReturnValueOnce(true).mockReturnValueOnce(false); + mockStatement.getString + .mockReturnValueOnce('testKey') + .mockReturnValueOnce('testValue'); + + const data = vatStore2.getKVData(); + expect(mockStatement.bind).toHaveBeenCalledWith(['vat2']); + expect(data).toStrictEqual([['testKey', 'testValue']]); + }); +}); + +describe('transaction management', () => { + beforeEach(() => { + Object.values(mockStatement).forEach((mock) => { + if (typeof mock === 'function' && mock.mockReset) { + mock.mockReset(); + } + }); + mockDb.exec.mockReset(); + mockDb._inTx = false; + mockDb._spStack = []; + }); + + it('safeMutate rollbacks transaction on error', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb._inTx = false; + mockDb._spStack = []; + mockStatement.step.mockImplementationOnce(() => { + throw new Error('Database error'); + }); + const vatStore = db.makeVatStore('test-vat'); + expect(() => vatStore.updateKVData([['key', 'value']], [])).toThrow( + 'Database error', + ); + expect(mockStatement.step).toHaveBeenCalled(); + }); + + it('safeMutate does not commit if already in transaction', async () => { + const db = await makeSQLKernelDatabase({}); + mockDb._inTx = true; + mockDb._spStack = []; + const vatStore = db.makeVatStore('test-vat'); + vatStore.updateKVData([['key', 'value']], []); + expect(mockStatement.step).not.toHaveBeenCalledWith( + expect.objectContaining({ sql: 'BEGIN TRANSACTION' }), + ); + expect(mockStatement.step).not.toHaveBeenCalledWith( + expect.objectContaining({ sql: 'COMMIT TRANSACTION' }), + ); + }); }); diff --git a/packages/kernel-store/src/sqlite/wasm.ts b/packages/kernel-store/src/sqlite/wasm.ts index d407e417a..f52af874d 100644 --- a/packages/kernel-store/src/sqlite/wasm.ts +++ b/packages/kernel-store/src/sqlite/wasm.ts @@ -1,11 +1,21 @@ import { Logger } from '@metamask/logger'; -import type { Database, PreparedStatement } from '@sqlite.org/sqlite-wasm'; +import type { Database as SqliteDatabase } from '@sqlite.org/sqlite-wasm'; import sqlite3InitModule from '@sqlite.org/sqlite-wasm'; -import { DEFAULT_DB_FILENAME, SQL_QUERIES } from './common.ts'; +import { + DEFAULT_DB_FILENAME, + assertSafeIdentifier, + SQL_QUERIES, +} from './common.ts'; import { getDBFolder } from './env.ts'; import type { KVStore, VatStore, KernelDatabase } from '../types.ts'; +export type Database = SqliteDatabase & { + _inTx: boolean; + // stack of active savepoint names + _spStack: string[]; +}; + /** * Ensure that SQLite is initialized. * @@ -14,7 +24,7 @@ import type { KVStore, VatStore, KernelDatabase } from '../types.ts'; */ export async function initDB(dbFilename: string): Promise { const sqlite3 = await sqlite3InitModule(); - let db: Database; + let db: SqliteDatabase; if (sqlite3.oo1.OpfsDb) { const dbName = dbFilename.startsWith(':') @@ -26,23 +36,11 @@ export async function initDB(dbFilename: string): Promise { db = new sqlite3.oo1.DB(`:memory:`, 'cw'); } - return db; -} + const dbWithTx = db as Database; + dbWithTx._inTx = false; + dbWithTx._spStack = []; -/** - * Helper function to paper over SQLite-wasm awfulness. Runs a prepared - * statement as it would be run in a more sensible API. - * - * @param stmt - A prepared statement to run. - * @param bindings - Optional parameters to bind for execution. - */ -// eslint-disable-next-line @typescript-eslint/no-unused-vars -function run(stmt: PreparedStatement, ...bindings: string[]): void { - if (bindings && bindings.length > 0) { - stmt.bind(bindings); - } - stmt.step(); - stmt.reset(); + return dbWithTx; } /** @@ -177,6 +175,71 @@ export async function makeSQLKernelDatabase({ const sqlKVClear = db.prepare(SQL_QUERIES.CLEAR); const sqlKVClearVS = db.prepare(SQL_QUERIES.CLEAR_VS); + const sqlVatstoreGetAll = db.prepare(SQL_QUERIES.GET_ALL_VS); + const sqlVatstoreSet = db.prepare(SQL_QUERIES.SET_VS); + const sqlVatstoreDelete = db.prepare(SQL_QUERIES.DELETE_VS); + const sqlVatstoreDeleteAll = db.prepare(SQL_QUERIES.DELETE_VS_ALL); + const sqlBeginTransaction = db.prepare(SQL_QUERIES.BEGIN_TRANSACTION); + const sqlCommitTransaction = db.prepare(SQL_QUERIES.COMMIT_TRANSACTION); + const sqlAbortTransaction = db.prepare(SQL_QUERIES.ABORT_TRANSACTION); + + /** + * Begin a transaction if not already in one + * + * @returns True if a new transaction was started, false if already in one + */ + function beginIfNeeded(): boolean { + if (db._inTx) { + return false; + } + sqlBeginTransaction.step(); + sqlBeginTransaction.reset(); + db._inTx = true; + return true; + } + + /** + * Commit a transaction if one is active and no savepoints remain + */ + function commitIfNeeded(): void { + if (db._inTx && db._spStack.length === 0) { + sqlCommitTransaction.step(); + sqlCommitTransaction.reset(); + db._inTx = false; + } + } + + /** + * Rollback a transaction + */ + function rollbackIfNeeded(): void { + if (db._inTx) { + sqlAbortTransaction.step(); + sqlAbortTransaction.reset(); + db._inTx = false; + db._spStack.length = 0; + } + } + + /** + * Safely mutate the database with proper transaction management + * + * @param mutator - Function that performs the database mutations + */ + function safeMutate(mutator: () => void): void { + const startedTx = beginIfNeeded(); + try { + mutator(); + if (startedTx) { + commitIfNeeded(); + } + } catch (error) { + if (startedTx) { + rollbackIfNeeded(); + } + throw error; + } + } /** * Delete everything from the database. @@ -218,14 +281,6 @@ export async function makeSQLKernelDatabase({ return results; } - const sqlVatstoreGetAll = db.prepare(SQL_QUERIES.GET_ALL_VS); - const sqlVatstoreSet = db.prepare(SQL_QUERIES.SET_VS); - const sqlVatstoreDelete = db.prepare(SQL_QUERIES.DELETE_VS); - const sqlVatstoreDeleteAll = db.prepare(SQL_QUERIES.DELETE_VS_ALL); - const sqlBeginTransaction = db.prepare(SQL_QUERIES.BEGIN_TRANSACTION); - const sqlCommitTransaction = db.prepare(SQL_QUERIES.COMMIT_TRANSACTION); - const sqlAbortTransaction = db.prepare(SQL_QUERIES.ABORT_TRANSACTION); - /** * Create a new VatStore for a vat. * @@ -261,9 +316,7 @@ export async function makeSQLKernelDatabase({ * @param deletes - A set of keys that have been deleted. */ function updateKVData(sets: [string, string][], deletes: string[]): void { - try { - sqlBeginTransaction.step(); - sqlBeginTransaction.reset(); + safeMutate(() => { for (const [key, value] of sets) { sqlVatstoreSet.bind([vatID, key, value]); sqlVatstoreSet.step(); @@ -274,13 +327,7 @@ export async function makeSQLKernelDatabase({ sqlVatstoreDelete.step(); sqlVatstoreDelete.reset(); } - sqlCommitTransaction.step(); - sqlCommitTransaction.reset(); - } catch (problem) { - sqlAbortTransaction.step(); - sqlAbortTransaction.reset(); - throw problem; - } + }); } return { @@ -300,11 +347,68 @@ export async function makeSQLKernelDatabase({ sqlVatstoreDeleteAll.reset(); } + /** + * Create a savepoint in the database. + * + * @param name - The name of the savepoint. + */ + function createSavepoint(name: string): void { + // We must be in a transaction when creating the savepoint or releasing it + // later will cause an autocommit. + // See https://github.com/Agoric/agoric-sdk/issues/8423 + beginIfNeeded(); + assertSafeIdentifier(name); + const query = SQL_QUERIES.CREATE_SAVEPOINT.replace('%NAME%', name); + db.exec(query); + db._spStack.push(name); + } + + /** + * Rollback to a savepoint in the database. + * + * @param name - The name of the savepoint. + */ + function rollbackSavepoint(name: string): void { + assertSafeIdentifier(name); + const idx = db._spStack.lastIndexOf(name); + if (idx < 0) { + throw new Error(`No such savepoint: ${name}`); + } + const query = SQL_QUERIES.ROLLBACK_SAVEPOINT.replace('%NAME%', name); + db.exec(query); + db._spStack.splice(idx); + if (db._spStack.length === 0) { + rollbackIfNeeded(); + } + } + + /** + * Release a savepoint in the database. + * + * @param name - The name of the savepoint. + */ + function releaseSavepoint(name: string): void { + assertSafeIdentifier(name); + const idx = db._spStack.lastIndexOf(name); + if (idx < 0) { + throw new Error(`No such savepoint: ${name}`); + } + const query = SQL_QUERIES.RELEASE_SAVEPOINT.replace('%NAME%', name); + db.exec(query); + db._spStack.splice(idx); + if (db._spStack.length === 0) { + commitIfNeeded(); + } + } + return { kernelKVStore: kvStore, clear: kvClear, executeQuery, makeVatStore, deleteVatStore, + createSavepoint, + rollbackSavepoint, + releaseSavepoint, }; } diff --git a/packages/kernel-store/src/types.ts b/packages/kernel-store/src/types.ts index 64fd0c6e7..d37bd524b 100644 --- a/packages/kernel-store/src/types.ts +++ b/packages/kernel-store/src/types.ts @@ -30,4 +30,7 @@ export type KernelDatabase = { clear(): void; makeVatStore(vatID: string): VatStore; deleteVatStore(vatID: string): void; + createSavepoint(name: string): void; + rollbackSavepoint(name: string): void; + releaseSavepoint(name: string): void; }; diff --git a/packages/kernel-test/src/exo.test.ts b/packages/kernel-test/src/exo.test.ts index 4c06f7c79..19ae1faf3 100644 --- a/packages/kernel-test/src/exo.test.ts +++ b/packages/kernel-test/src/exo.test.ts @@ -1,4 +1,3 @@ -import '@metamask/kernel-shims/endoify'; import { makeSQLKernelDatabase } from '@metamask/kernel-store/sqlite/nodejs'; import { waitUntilQuiescent } from '@metamask/kernel-utils'; import type { LogEntry } from '@metamask/logger'; diff --git a/packages/kernel-test/src/garbage-collection.test.ts b/packages/kernel-test/src/garbage-collection.test.ts index d6beb1e28..bd8687d7e 100644 --- a/packages/kernel-test/src/garbage-collection.test.ts +++ b/packages/kernel-test/src/garbage-collection.test.ts @@ -1,4 +1,3 @@ -import '@metamask/kernel-shims/endoify'; import type { KernelDatabase } from '@metamask/kernel-store'; import { makeSQLKernelDatabase } from '@metamask/kernel-store/sqlite/nodejs'; import { waitUntilQuiescent } from '@metamask/kernel-utils'; diff --git a/packages/kernel-test/src/liveslots.test.ts b/packages/kernel-test/src/liveslots.test.ts index 7209bcabd..c8148ea9d 100644 --- a/packages/kernel-test/src/liveslots.test.ts +++ b/packages/kernel-test/src/liveslots.test.ts @@ -1,4 +1,3 @@ -import '@metamask/kernel-shims/endoify'; import { makeSQLKernelDatabase } from '@metamask/kernel-store/sqlite/nodejs'; import { waitUntilQuiescent } from '@metamask/kernel-utils'; import type { LogEntry } from '@metamask/logger'; diff --git a/packages/kernel-test/src/logger.test.ts b/packages/kernel-test/src/logger.test.ts index 19f65fa62..29a154188 100644 --- a/packages/kernel-test/src/logger.test.ts +++ b/packages/kernel-test/src/logger.test.ts @@ -1,4 +1,3 @@ -import '@metamask/kernel-shims/endoify'; import { makeSQLKernelDatabase } from '@metamask/kernel-store/sqlite/nodejs'; import { waitUntilQuiescent } from '@metamask/kernel-utils'; import type { VatId } from '@metamask/ocap-kernel'; diff --git a/packages/kernel-test/src/resume.test.ts b/packages/kernel-test/src/resume.test.ts index 2489e3ec9..ad2b702ff 100644 --- a/packages/kernel-test/src/resume.test.ts +++ b/packages/kernel-test/src/resume.test.ts @@ -1,4 +1,3 @@ -import '@metamask/kernel-shims/endoify'; import { makeSQLKernelDatabase } from '@metamask/kernel-store/sqlite/nodejs'; import { waitUntilQuiescent } from '@metamask/kernel-utils'; import { describe, expect, it } from 'vitest'; diff --git a/packages/kernel-test/src/savepoint.test.ts b/packages/kernel-test/src/savepoint.test.ts new file mode 100644 index 000000000..7ce4c80b1 --- /dev/null +++ b/packages/kernel-test/src/savepoint.test.ts @@ -0,0 +1,82 @@ +import { makeSQLKernelDatabase } from '@metamask/kernel-store/sqlite/nodejs'; +import { describe, it, expect, vi } from 'vitest'; + +/** + * Helper to create a test database with some initial data + * + * @returns A SQLite database instance with initial data + */ +async function setupTestDb() { + const db = await makeSQLKernelDatabase({ + dbFilename: ':memory:', + label: 'savepoint-test', + }); + const { kernelKVStore } = db; + kernelKVStore.set('key1', 'value1'); + kernelKVStore.set('key2', 'value2'); + return db; +} + +describe('Savepoint functionality', () => { + it('allows creating and releasing a savepoint', async () => { + const db = await setupTestDb(); + db.createSavepoint('test_point'); + db.kernelKVStore.set('key3', 'value3'); + db.releaseSavepoint('test_point'); + expect(db.kernelKVStore.get('key3')).toBe('value3'); + }); + + it('can rollback to a savepoint to undo changes', async () => { + const db = await setupTestDb(); + expect(db.kernelKVStore.get('key1')).toBe('value1'); + expect(db.kernelKVStore.get('key2')).toBe('value2'); + db.createSavepoint('test_point'); + db.kernelKVStore.set('key1', 'modified1'); + db.kernelKVStore.set('key3', 'value3'); + db.kernelKVStore.delete('key2'); + expect(db.kernelKVStore.get('key1')).toBe('modified1'); + expect(db.kernelKVStore.get('key2')).toBeUndefined(); + expect(db.kernelKVStore.get('key3')).toBe('value3'); + db.rollbackSavepoint('test_point'); + expect(db.kernelKVStore.get('key1')).toBe('value1'); + expect(db.kernelKVStore.get('key2')).toBe('value2'); + expect(db.kernelKVStore.get('key3')).toBeUndefined(); + }); + + it('supports nested savepoints', async () => { + const db = await setupTestDb(); + db.createSavepoint('outer'); + db.kernelKVStore.set('key3', 'value3'); + db.createSavepoint('inner'); + db.kernelKVStore.set('key4', 'value4'); + expect(db.kernelKVStore.get('key3')).toBe('value3'); + expect(db.kernelKVStore.get('key4')).toBe('value4'); + db.rollbackSavepoint('inner'); + expect(db.kernelKVStore.get('key3')).toBe('value3'); + expect(db.kernelKVStore.get('key4')).toBeUndefined(); + db.releaseSavepoint('outer'); + expect(db.kernelKVStore.get('key3')).toBe('value3'); + }); + + it('rejects invalid savepoint names', async () => { + const db = await setupTestDb(); + expect(() => db.createSavepoint('invalid-name')).toThrow( + 'Invalid identifier', + ); + expect(() => db.createSavepoint('123numeric')).toThrow( + 'Invalid identifier', + ); + expect(() => db.createSavepoint('spaces not allowed')).toThrow( + 'Invalid identifier', + ); + }); + + it('sanitizes savepoint names to prevent SQL injection', async () => { + const db = await setupTestDb(); + const executeQuerySpy = vi.spyOn(db, 'executeQuery'); + expect(() => db.createSavepoint("point'; DROP TABLE kv--")).toThrow( + 'Invalid identifier', + ); + expect(executeQuerySpy).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/kernel-test/src/supervisor.test.ts b/packages/kernel-test/src/supervisor.test.ts index 8355b6824..614df1768 100644 --- a/packages/kernel-test/src/supervisor.test.ts +++ b/packages/kernel-test/src/supervisor.test.ts @@ -1,4 +1,3 @@ -import '@metamask/kernel-shims/endoify'; import { delay } from '@metamask/kernel-utils'; import type { JsonRpcMessage } from '@metamask/kernel-utils'; import type { VatConfig } from '@metamask/ocap-kernel'; diff --git a/packages/kernel-test/vitest.config.ts b/packages/kernel-test/vitest.config.ts index 1f60917fc..08608da08 100644 --- a/packages/kernel-test/vitest.config.ts +++ b/packages/kernel-test/vitest.config.ts @@ -1,3 +1,4 @@ +import path from 'path'; import { defineProject, mergeConfig } from 'vitest/config'; import defaultConfig from '../../vitest.config.ts'; @@ -8,6 +9,7 @@ const config = mergeConfig( test: { name: 'kernel-test', pool: 'forks', + setupFiles: path.resolve(__dirname, '../kernel-shims/src/endoify.js'), }, }), ); diff --git a/packages/ocap-kernel/src/Kernel.test.ts b/packages/ocap-kernel/src/Kernel.test.ts index baf85a236..b30c6bd0d 100644 --- a/packages/ocap-kernel/src/Kernel.test.ts +++ b/packages/ocap-kernel/src/Kernel.test.ts @@ -335,7 +335,6 @@ describe('Kernel', () => { await kernel.launchVat(config); const vats = kernel.getVats(); expect(vats).toHaveLength(1); - console.log(vats); expect(vats).toStrictEqual([ { id: 'v1', diff --git a/packages/ocap-kernel/src/Kernel.ts b/packages/ocap-kernel/src/Kernel.ts index 8b84b69af..665f87fae 100644 --- a/packages/ocap-kernel/src/Kernel.ts +++ b/packages/ocap-kernel/src/Kernel.ts @@ -2,6 +2,7 @@ import type { CapData } from '@endo/marshal'; import { StreamReadError, VatAlreadyExistsError, + VatDeletedError, VatNotFoundError, } from '@metamask/kernel-errors'; import { RpcService } from '@metamask/kernel-rpc-methods'; @@ -91,7 +92,10 @@ export class Kernel { if (options.resetStorage) { this.#resetKernelState(); } - this.#kernelQueue = new KernelQueue(this.#kernelStore); + this.#kernelQueue = new KernelQueue( + this.#kernelStore, + this.terminateVat.bind(this), + ); this.#kernelRouter = new KernelRouter( this.#kernelStore, this.#kernelQueue, @@ -302,14 +306,29 @@ export class Kernel { * @param vatId - The ID of the vat. * @param terminating - If true, the vat is being killed, if false, it's being * restarted. + * @param reason - If the vat is being terminated, the reason for the termination. */ - async #stopVat(vatId: VatId, terminating: boolean): Promise { + async #stopVat( + vatId: VatId, + terminating: boolean, + reason?: CapData, + ): Promise { const vat = this.#getVat(vatId); if (!vat) { throw new VatNotFoundError(vatId); } - await vat.terminate(terminating); - await this.#vatWorkerService.terminate(vatId).catch(this.#logger.error); + + let terminationError: Error | undefined; + if (reason) { + terminationError = new Error(`Vat termination: ${reason.body}`); + } else if (terminating) { + terminationError = new VatDeletedError(vatId); + } + + await this.#vatWorkerService + .terminate(vatId, terminationError) + .catch(this.#logger.error); + await vat.terminate(terminating, terminationError); this.#vats.delete(vatId); } @@ -317,9 +336,10 @@ export class Kernel { * Terminate a vat with extreme prejudice. * * @param vatId - The ID of the vat. + * @param reason - If the vat is being terminated, the reason for the termination. */ - async terminateVat(vatId: VatId): Promise { - await this.#stopVat(vatId, true); + async terminateVat(vatId: VatId, reason?: CapData): Promise { + await this.#stopVat(vatId, true, reason); this.#kernelStore.deleteVatConfig(vatId); // Mark for deletion (which will happen later, in vat-cleanup events) this.#kernelStore.markVatAsTerminated(vatId); diff --git a/packages/ocap-kernel/src/KernelQueue.test.ts b/packages/ocap-kernel/src/KernelQueue.test.ts index d0fcbc26b..75c54b59b 100644 --- a/packages/ocap-kernel/src/KernelQueue.test.ts +++ b/packages/ocap-kernel/src/KernelQueue.test.ts @@ -5,7 +5,9 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import type { MockInstance } from 'vitest'; import { KernelQueue } from './KernelQueue.ts'; +import * as gc from './services/garbage-collection.ts'; import type { KernelStore } from './store/index.ts'; +import * as types from './types.ts'; import type { KRef, Message, @@ -25,6 +27,7 @@ describe('KernelQueue', () => { let kernelStore: KernelStore; let kernelQueue: KernelQueue; let mockPromiseKit: ReturnType; + let terminateVat: (vatId: string, reason?: CapData) => Promise; beforeEach(() => { mockPromiseKit = { @@ -33,6 +36,9 @@ describe('KernelQueue', () => { reject: vi.fn(), }; (makePromiseKit as unknown as MockInstance).mockReturnValue(mockPromiseKit); + + terminateVat = vi.fn().mockResolvedValue(undefined); + kernelStore = { nextTerminatedVatCleanup: vi.fn(), collectGarbage: vi.fn(), @@ -45,9 +51,13 @@ describe('KernelQueue', () => { resolveKernelPromise: vi.fn(), nextReapAction: vi.fn().mockReturnValue(null), getGCActions: vi.fn().mockReturnValue([]), + startCrank: vi.fn(), + endCrank: vi.fn(), + createCrankSavepoint: vi.fn(), + rollbackCrank: vi.fn(), } as unknown as KernelStore; - kernelQueue = new KernelQueue(kernelStore); + kernelQueue = new KernelQueue(kernelStore, terminateVat); }); describe('run', () => { @@ -63,11 +73,80 @@ describe('KernelQueue', () => { (kernelStore.dequeueRun as unknown as MockInstance).mockReturnValue( mockItem, ); + const processGCActionSetSpy = vi.spyOn(gc, 'processGCActionSet'); const deliverError = new Error('stop'); const deliver = vi.fn().mockRejectedValue(deliverError); await expect(kernelQueue.run(deliver)).rejects.toBe(deliverError); + expect(kernelStore.startCrank).toHaveBeenCalled(); + expect(kernelStore.createCrankSavepoint).toHaveBeenCalledWith('start'); + expect(processGCActionSetSpy).toHaveBeenCalled(); + expect(kernelStore.nextReapAction).toHaveBeenCalled(); expect(kernelStore.nextTerminatedVatCleanup).toHaveBeenCalled(); expect(deliver).toHaveBeenCalledWith(mockItem); + expect(kernelStore.endCrank).toHaveBeenCalled(); + }); + + it('rolls back crank when deliver returns abort', async () => { + const mockItem: RunQueueItem = { + type: 'send', + target: 'ko123', + message: {} as Message, + }; + (kernelStore.runQueueLength as unknown as MockInstance) + .mockReturnValueOnce(1) + .mockReturnValue(0); + (kernelStore.dequeueRun as unknown as MockInstance).mockReturnValueOnce( + mockItem, + ); + const deliver = vi.fn().mockResolvedValue({ abort: true }); + const collectGarbageError = new Error( + 'wakeUpTheRunQueue function already set', + ); + (kernelStore.collectGarbage as unknown as MockInstance).mockRejectedValue( + collectGarbageError, + ); + await expect(kernelQueue.run(deliver)).rejects.toThrow( + collectGarbageError.message, + ); + expect(kernelStore.startCrank).toHaveBeenCalled(); + expect(kernelStore.createCrankSavepoint).toHaveBeenCalledWith('start'); + expect(deliver).toHaveBeenCalledWith(mockItem); + expect(kernelStore.rollbackCrank).toHaveBeenCalledWith('start'); + expect(kernelStore.collectGarbage).toHaveBeenCalled(); + expect(kernelStore.endCrank).toHaveBeenCalled(); + }); + + it('terminates vat when deliver returns terminate', async () => { + const mockItem: RunQueueItem = { + type: 'send', + target: 'ko123', + message: {} as Message, + }; + const terminateInfo = { vatId: 'v1', info: { body: '"test"' } }; + (kernelStore.runQueueLength as unknown as MockInstance) + .mockReturnValueOnce(1) + .mockReturnValue(0); + (kernelStore.dequeueRun as unknown as MockInstance).mockReturnValueOnce( + mockItem, + ); + const deliver = vi.fn().mockResolvedValue({ terminate: terminateInfo }); + const collectGarbageError = new Error( + 'wakeUpTheRunQueue function already set', + ); + (kernelStore.collectGarbage as unknown as MockInstance).mockRejectedValue( + collectGarbageError, + ); + await expect(kernelQueue.run(deliver)).rejects.toThrow( + collectGarbageError.message, + ); + expect(kernelStore.startCrank).toHaveBeenCalled(); + expect(deliver).toHaveBeenCalledWith(mockItem); + expect(terminateVat).toHaveBeenCalledWith( + terminateInfo.vatId, + terminateInfo.info, + ); + expect(kernelStore.collectGarbage).toHaveBeenCalled(); + expect(kernelStore.endCrank).toHaveBeenCalled(); }); }); @@ -235,6 +314,78 @@ describe('KernelQueue', () => { expect(kernelQueue.subscriptions.has(kpid)).toBe(false); }); + it('handles resolutions with undefined vatId (kernel decider)', () => { + const kpid = 'kp123'; + const resolution: VatOneResolution = [ + kpid, + false, + { body: 'resolved value', slots: ['slot1'] } as CapData, + ]; + (kernelStore.getKernelPromise as unknown as MockInstance).mockReturnValue( + { + state: 'unresolved', + decider: undefined, + subscribers: ['v2'], + }, + ); + const resolveHandler = vi.fn(); + kernelQueue.subscriptions.set(kpid, resolveHandler); + const insistVatIdSpy = vi.spyOn(types, 'insistVatId'); + kernelQueue.resolvePromises(undefined, [resolution]); + expect(insistVatIdSpy).not.toHaveBeenCalled(); + expect(kernelStore.incrementRefCount).toHaveBeenCalledWith( + kpid, + 'resolve|kpid', + ); + expect(kernelStore.incrementRefCount).toHaveBeenCalledWith( + 'slot1', + 'resolve|slot', + ); + expect(kernelStore.enqueueRun).toHaveBeenCalledWith({ + type: 'notify', + vatId: 'v2', + kpid, + }); + expect(kernelStore.resolveKernelPromise).toHaveBeenCalledWith( + kpid, + false, + resolution[2], + ); + expect(resolveHandler).toHaveBeenCalledWith(resolution[2]); + expect(kernelQueue.subscriptions.has(kpid)).toBe(false); + insistVatIdSpy.mockRestore(); + }); + + it('handles promises with no subscribers', () => { + const vatId = 'v1'; + const kpid = 'kpNoSubscribers'; + const resolution: VatOneResolution = [ + kpid, + false, + { body: 'resolved value', slots: [] } as CapData, + ]; + (kernelStore.getKernelPromise as unknown as MockInstance).mockReturnValue( + { + state: 'unresolved', + decider: vatId, + subscribers: [], + }, + ); + const resolveHandler = vi.fn(); + kernelQueue.subscriptions.set(kpid, resolveHandler); + kernelQueue.resolvePromises(vatId, [resolution]); + expect(kernelStore.enqueueRun).not.toHaveBeenCalledWith( + expect.objectContaining({ type: 'notify' }), + ); + expect(kernelStore.resolveKernelPromise).toHaveBeenCalledWith( + kpid, + false, + resolution[2], + ); + expect(resolveHandler).toHaveBeenCalledWith(resolution[2]); + expect(kernelQueue.subscriptions.has(kpid)).toBe(false); + }); + it('throws error if a promise is already resolved', () => { const vatId = 'v1'; const kpid = 'kp123'; diff --git a/packages/ocap-kernel/src/KernelQueue.ts b/packages/ocap-kernel/src/KernelQueue.ts index 2a2e9af92..280c2a4d6 100644 --- a/packages/ocap-kernel/src/KernelQueue.ts +++ b/packages/ocap-kernel/src/KernelQueue.ts @@ -7,6 +7,7 @@ import { kser } from './services/kernel-marshal.ts'; import type { KernelStore } from './store/index.ts'; import { insistVatId } from './types.ts'; import type { + CrankResults, KRef, Message, RunQueueItem, @@ -26,14 +27,24 @@ export class KernelQueue { /** Storage holding the kernel's own persistent state */ readonly #kernelStore: KernelStore; + /** A function that terminates a vat. */ + readonly #terminateVat: ( + vatId: VatId, + reason?: CapData, + ) => Promise; + /** Message results that the kernel itself has subscribed to */ readonly subscriptions: Map) => void> = new Map(); /** Thunk to signal run queue transition from empty to non-empty */ #wakeUpTheRunQueue: (() => void) | null; - constructor(kernelStore: KernelStore) { + constructor( + kernelStore: KernelStore, + terminateVat: (vatId: VatId, reason?: CapData) => Promise, + ) { this.#kernelStore = kernelStore; + this.#terminateVat = terminateVat; this.#wakeUpTheRunQueue = null; } @@ -43,10 +54,27 @@ export class KernelQueue { * * @param deliver - A function that delivers an item to the kernel. */ - async run(deliver: (item: RunQueueItem) => Promise): Promise { + async run( + deliver: (item: RunQueueItem) => Promise, + ): Promise { for await (const item of this.#runQueueItems()) { this.#kernelStore.nextTerminatedVatCleanup(); - await deliver(item); + const crankResults = await deliver(item); + if (crankResults?.abort) { + // Rollback the kernel state to before the failed delivery attempt. + // For active vats, this allows the message to be retried in a future crank. + // For terminated vats, the message will just go splat. + this.#kernelStore.rollbackCrank('start'); + // TODO: Currently all errors terminate the vat, but instead we could + // restart it and terminate the vat only after a certain number of failed + // retries. This is probably where we should implement the vat restart logic. + } + // Vat termination during delivery is triggered by an illegal syscall + // or by syscall.exit(). + if (crankResults?.terminate) { + const { vatId, info } = crankResults.terminate; + await this.#terminateVat(vatId, info); + } this.#kernelStore.collectGarbage(); } } @@ -58,34 +86,40 @@ export class KernelQueue { */ async *#runQueueItems(): AsyncGenerator { for (;;) { - const gcAction = processGCActionSet(this.#kernelStore); - if (gcAction) { - yield gcAction; - continue; - } + this.#kernelStore.startCrank(); + try { + this.#kernelStore.createCrankSavepoint('start'); + const gcAction = processGCActionSet(this.#kernelStore); + if (gcAction) { + yield gcAction; + continue; + } - const reapAction = this.#kernelStore.nextReapAction(); - if (reapAction) { - yield reapAction; - continue; - } + const reapAction = this.#kernelStore.nextReapAction(); + if (reapAction) { + yield reapAction; + continue; + } - while (this.#kernelStore.runQueueLength() > 0) { - const item = this.#kernelStore.dequeueRun(); - if (item) { - yield item; - } else { - break; + while (this.#kernelStore.runQueueLength() > 0) { + const item = this.#kernelStore.dequeueRun(); + if (item) { + yield item; + } else { + break; + } } - } - if (this.#kernelStore.runQueueLength() === 0) { - const { promise, resolve } = makePromiseKit(); - if (this.#wakeUpTheRunQueue !== null) { - Fail`wakeUpTheRunQueue function already set`; + if (this.#kernelStore.runQueueLength() === 0) { + const { promise, resolve } = makePromiseKit(); + if (this.#wakeUpTheRunQueue !== null) { + Fail`wakeUpTheRunQueue function already set`; + } + this.#wakeUpTheRunQueue = resolve; + await promise; } - this.#wakeUpTheRunQueue = resolve; - await promise; + } finally { + this.#kernelStore.endCrank(); } } } diff --git a/packages/ocap-kernel/src/KernelRouter.test.ts b/packages/ocap-kernel/src/KernelRouter.test.ts index 3216dec40..2d3fc63dc 100644 --- a/packages/ocap-kernel/src/KernelRouter.test.ts +++ b/packages/ocap-kernel/src/KernelRouter.test.ts @@ -13,6 +13,7 @@ import type { RunQueueItemBringOutYourDead, VatId, GCRunQueueType, + CrankResults, } from './types.ts'; import type { VatHandle } from './VatHandle.ts'; @@ -31,14 +32,16 @@ describe('KernelRouter', () => { let kernelRouter: KernelRouter; beforeEach(() => { - // Mock VatHandle + // Mock VatHandle with more detailed return values + const mockCrankResults: CrankResults = { didDelivery: 'v1' }; + vatHandle = { - deliverMessage: vi.fn().mockResolvedValue(undefined), - deliverNotify: vi.fn().mockResolvedValue(undefined), - deliverDropExports: vi.fn().mockResolvedValue(undefined), - deliverRetireExports: vi.fn().mockResolvedValue(undefined), - deliverRetireImports: vi.fn().mockResolvedValue(undefined), - deliverBringOutYourDead: vi.fn().mockResolvedValue(undefined), + deliverMessage: vi.fn().mockResolvedValue(mockCrankResults), + deliverNotify: vi.fn().mockResolvedValue(mockCrankResults), + deliverDropExports: vi.fn().mockResolvedValue(mockCrankResults), + deliverRetireExports: vi.fn().mockResolvedValue(mockCrankResults), + deliverRetireImports: vi.fn().mockResolvedValue(mockCrankResults), + deliverBringOutYourDead: vi.fn().mockResolvedValue(mockCrankResults), } as unknown as VatHandle; // Mock getVat function @@ -65,6 +68,7 @@ describe('KernelRouter', () => { krefsToExistingErefs: vi.fn((_vatId: string, krefs: string[]) => krefs.map((kref: string) => `translated-${kref}`), ) as unknown as MockInstance, + createCrankSavepoint: vi.fn(), } as unknown as KernelStore; // Mock KernelQueue @@ -78,16 +82,23 @@ describe('KernelRouter', () => { describe('deliver', () => { describe('send', () => { - it('delivers a send message to a vat with an object target', async () => { + it('delivers a send message to a vat with an object target and returns crank results', async () => { // Setup the kernel store to return an owner for the target const vatId = 'v1'; const target = 'ko123'; (kernelStore.getOwner as unknown as MockInstance).mockReturnValueOnce( vatId, ); - (kernelStore.erefToKref as unknown as MockInstance).mockReturnValueOnce( - 'not-the-target', - ); + + // Create a mock crank result that the vat will return + const mockCrankResults: CrankResults = { + didDelivery: vatId, + abort: false, + }; + ( + vatHandle.deliverMessage as unknown as MockInstance + ).mockResolvedValueOnce(mockCrankResults); + // Create a send message const message: Message = { methargs: { body: 'method args', slots: ['slot1', 'slot2'] }, @@ -98,13 +109,16 @@ describe('KernelRouter', () => { target, message: message as unknown as SwingsetMessage, }; - await kernelRouter.deliver(sendItem); - // Verify the message was delivered to the vat + + const result = await kernelRouter.deliver(sendItem); + + // Verify the message was delivered to the vat and results returned expect(getVat).toHaveBeenCalledWith(vatId); expect(vatHandle.deliverMessage).toHaveBeenCalledWith( `translated-${target}`, message, ); + expect(result).toStrictEqual(mockCrankResults); expect(kernelStore.decrementRefCount).toHaveBeenCalledWith( 'slot1', 'deliver|send|slot', @@ -123,7 +137,7 @@ describe('KernelRouter', () => { ); }); - it('splats a message when target has no owner', async () => { + it('splats a message when target has no owner and returns undefined', async () => { // Setup the kernel store to return no owner for the target (kernelStore.getOwner as unknown as MockInstance).mockReturnValueOnce( null, @@ -140,10 +154,13 @@ describe('KernelRouter', () => { target, message: message as unknown as SwingsetMessage, }; - await kernelRouter.deliver(sendItem); + const result = await kernelRouter.deliver(sendItem); + // Verify the message was not delivered to any vat and resources were cleaned up expect(getVat).not.toHaveBeenCalled(); expect(vatHandle.deliverMessage).not.toHaveBeenCalled(); + expect(result).toBeUndefined(); + // Verify refcounts were decremented expect(kernelStore.decrementRefCount).toHaveBeenCalledWith( target, @@ -170,13 +187,14 @@ describe('KernelRouter', () => { ); }); - it('enqueues a message on an unresolved promise', async () => { + it('enqueues a message on an unresolved promise and returns undefined', async () => { // Setup a promise reference and unresolved promise in the kernel store const target = 'kp123'; ( kernelStore.getKernelPromise as unknown as MockInstance ).mockReturnValueOnce({ state: 'unresolved', + value: { body: JSON.stringify({ status: 'unresolved' }), slots: [] }, }); // Create a send message const message: Message = { @@ -188,19 +206,108 @@ describe('KernelRouter', () => { target, message: message as unknown as SwingsetMessage, }; - await kernelRouter.deliver(sendItem); + const result = await kernelRouter.deliver(sendItem); + // Verify the message was enqueued on the promise expect(kernelStore.enqueuePromiseMessage).toHaveBeenCalledWith( target, message, ); + // Verify no vat interaction occurred + expect(getVat).not.toHaveBeenCalled(); + expect(vatHandle.deliverMessage).not.toHaveBeenCalled(); + expect(result).toBeUndefined(); + + // Verify that no refcount decrementation happened since we're requeuing + expect(kernelStore.decrementRefCount).not.toHaveBeenCalled(); + }); + + it('splats message when promise resolves to a non-object', async () => { + // Setup a fulfilled promise that doesn't resolve to an object + const promiseId = 'kp123'; + + ( + kernelStore.getKernelPromise as unknown as MockInstance + ).mockReturnValueOnce({ + state: 'fulfilled', + value: { + body: JSON.stringify({ value: 'not an object' }), + slots: [], + }, + }); + + // Create a send message to the promise + const message: Message = { + methargs: { body: 'method args', slots: [] }, + result: 'kp2', + }; + const sendItem: RunQueueItemSend = { + type: 'send', + target: promiseId, + message: message as unknown as SwingsetMessage, + }; + + const result = await kernelRouter.deliver(sendItem); + + // Message should be splatted, not delivered + expect(getVat).not.toHaveBeenCalled(); + expect(vatHandle.deliverMessage).not.toHaveBeenCalled(); + expect(result).toBeUndefined(); + + // Verify the result promise was rejected + expect(kernelQueue.resolvePromises).toHaveBeenCalledWith( + undefined, + expect.arrayContaining([ + expect.arrayContaining(['kp2', true, expect.anything()]), + ]), + ); + }); + + it('splats message when promise is rejected', async () => { + // Setup a rejected promise + const promiseId = 'kp123'; + const rejection = { + body: JSON.stringify({ error: 'rejection reason' }), + slots: [], + }; + + ( + kernelStore.getKernelPromise as unknown as MockInstance + ).mockReturnValueOnce({ + state: 'rejected', + value: rejection, + }); + + // Create a send message to the promise + const message: Message = { + methargs: { body: 'method args', slots: [] }, + result: 'kp2', + }; + const sendItem: RunQueueItemSend = { + type: 'send', + target: promiseId, + message: message as unknown as SwingsetMessage, + }; + + const result = await kernelRouter.deliver(sendItem); + + // Message should be splatted, not delivered expect(getVat).not.toHaveBeenCalled(); expect(vatHandle.deliverMessage).not.toHaveBeenCalled(); + expect(result).toBeUndefined(); + + // Verify the result promise was rejected with the same reason + expect(kernelQueue.resolvePromises).toHaveBeenCalledWith( + undefined, + expect.arrayContaining([ + expect.arrayContaining(['kp2', true, rejection]), + ]), + ); }); }); describe('notify', () => { - it('delivers a notify to a vat', async () => { + it('delivers a notify to a vat and returns crank results', async () => { const vatId = 'v1'; const kpid = 'kp123'; const notifyItem: RunQueueItemNotify = { @@ -208,30 +315,48 @@ describe('KernelRouter', () => { vatId, kpid, }; + // Mock a resolved promise ( kernelStore.getKernelPromise as unknown as MockInstance ).mockReturnValueOnce({ state: 'fulfilled', - value: { body: 'resolved value', slots: [] }, + value: { + body: JSON.stringify({ value: 'resolved value' }), + slots: [], + }, }); + // Mock that this promise is in the vat's clist (kernelStore.krefToEref as unknown as MockInstance).mockReturnValueOnce( 'p+123', ); + // Mock that there's a promise to retire ( kernelStore.getKpidsToRetire as unknown as MockInstance ).mockReturnValueOnce([kpid]); + // Mock the getKernelPromise for the target promise ( kernelStore.getKernelPromise as unknown as MockInstance ).mockReturnValueOnce({ state: 'fulfilled', - value: { body: 'target promise value', slots: [] }, + value: { + body: JSON.stringify({ value: 'target promise value' }), + slots: [], + }, }); + + // Mock crank results + const mockCrankResults: CrankResults = { didDelivery: vatId }; + ( + vatHandle.deliverNotify as unknown as MockInstance + ).mockResolvedValueOnce(mockCrankResults); + // Deliver the notify - await kernelRouter.deliver(notifyItem); + const result = await kernelRouter.deliver(notifyItem); + // Verify the notification was delivered to the vat expect(getVat).toHaveBeenCalledWith(vatId); expect(vatHandle.deliverNotify).toHaveBeenCalledWith(expect.any(Array)); @@ -239,9 +364,10 @@ describe('KernelRouter', () => { kpid, 'deliver|notify', ); + expect(result).toStrictEqual(mockCrankResults); }); - it('does nothing if the promise is not in vat clist', async () => { + it('returns didDelivery when promise is not in vat clist', async () => { const vatId = 'v1'; const kpid = 'kp123'; const notifyItem: RunQueueItemNotify = { @@ -249,21 +375,90 @@ describe('KernelRouter', () => { vatId, kpid, }; + // Mock a resolved promise ( kernelStore.getKernelPromise as unknown as MockInstance ).mockReturnValueOnce({ state: 'fulfilled', - value: { body: 'resolved value', slots: [] }, + value: { + body: JSON.stringify({ value: 'resolved value' }), + slots: [], + }, }); + // Mock that this promise is NOT in the vat's clist (kernelStore.krefToEref as unknown as MockInstance).mockReturnValueOnce( null, ); + + // Deliver the notify + const result = await kernelRouter.deliver(notifyItem); + + // Verify no notification was delivered to the vat + expect(vatHandle.deliverNotify).not.toHaveBeenCalled(); + expect(result).toStrictEqual({ didDelivery: vatId }); + }); + + it('returns didDelivery when no kpids to retire', async () => { + const vatId = 'v1'; + const kpid = 'kp123'; + const notifyItem: RunQueueItemNotify = { + type: 'notify', + vatId, + kpid, + }; + + // Mock a resolved promise + ( + kernelStore.getKernelPromise as unknown as MockInstance + ).mockReturnValueOnce({ + state: 'fulfilled', + value: { + body: JSON.stringify({ value: 'resolved value' }), + slots: [], + }, + }); + + // Mock that this promise is in the vat's clist + (kernelStore.krefToEref as unknown as MockInstance).mockReturnValueOnce( + 'p+123', + ); + + // Mock that there are no promises to retire + ( + kernelStore.getKpidsToRetire as unknown as MockInstance + ).mockReturnValueOnce([]); + // Deliver the notify - await kernelRouter.deliver(notifyItem); + const result = await kernelRouter.deliver(notifyItem); + // Verify no notification was delivered to the vat expect(vatHandle.deliverNotify).not.toHaveBeenCalled(); + expect(result).toStrictEqual({ didDelivery: vatId }); + }); + + it('throws if notification is for an unresolved promise', async () => { + const vatId = 'v1'; + const kpid = 'kp123'; + const notifyItem: RunQueueItemNotify = { + type: 'notify', + vatId, + kpid, + }; + + // Mock an unresolved promise with no value + ( + kernelStore.getKernelPromise as unknown as MockInstance + ).mockReturnValueOnce({ + state: 'unresolved', + value: null, + }); + + // Deliver the notify should throw with the expected error message + await expect(kernelRouter.deliver(notifyItem)).rejects.toThrow( + 'no value for promise kp123', + ); }); }); @@ -272,36 +467,59 @@ describe('KernelRouter', () => { ['dropExports', 'deliverDropExports'], ['retireExports', 'deliverRetireExports'], ['retireImports', 'deliverRetireImports'], - ])('delivers %s to a vat', async (actionType, deliverMethod) => { - const vatId = 'v1'; - const krefs = ['ko1', 'ko2']; - const gcAction: RunQueueItemGCAction = { - type: actionType as GCRunQueueType, - vatId, - krefs, - }; - // Deliver the GC action - await kernelRouter.deliver(gcAction); - // Verify the action was delivered to the vat - expect(getVat).toHaveBeenCalledWith(vatId); - expect( - vatHandle[deliverMethod as keyof VatHandle], - ).toHaveBeenCalledWith(krefs.map((kref) => `translated-${kref}`)); - }); + ])( + 'delivers %s to a vat and returns crank results', + async (actionType, deliverMethod) => { + const vatId = 'v1'; + const krefs = ['ko1', 'ko2']; + const gcAction: RunQueueItemGCAction = { + type: actionType as GCRunQueueType, + vatId, + krefs, + }; + + // Mock crank results + const mockCrankResults: CrankResults = { didDelivery: vatId }; + ( + vatHandle[ + deliverMethod as keyof VatHandle + ] as unknown as MockInstance + ).mockResolvedValueOnce(mockCrankResults); + + // Deliver the GC action + const result = await kernelRouter.deliver(gcAction); + + // Verify the action was delivered to the vat + expect(getVat).toHaveBeenCalledWith(vatId); + expect( + vatHandle[deliverMethod as keyof VatHandle], + ).toHaveBeenCalledWith(krefs.map((kref) => `translated-${kref}`)); + expect(result).toStrictEqual(mockCrankResults); + }, + ); }); describe('bringOutYourDead', () => { - it('delivers bringOutYourDead to a vat', async () => { + it('delivers bringOutYourDead to a vat and returns crank results', async () => { const vatId = 'v1'; const bringOutYourDeadItem: RunQueueItemBringOutYourDead = { type: 'bringOutYourDead', vatId, }; + + // Mock crank results + const mockCrankResults: CrankResults = { didDelivery: vatId }; + ( + vatHandle.deliverBringOutYourDead as unknown as MockInstance + ).mockResolvedValueOnce(mockCrankResults); + // Deliver the bringOutYourDead action - await kernelRouter.deliver(bringOutYourDeadItem); + const result = await kernelRouter.deliver(bringOutYourDeadItem); + // Verify the action was delivered to the vat expect(getVat).toHaveBeenCalledWith(vatId); expect(vatHandle.deliverBringOutYourDead).toHaveBeenCalled(); + expect(result).toStrictEqual(mockCrankResults); }); }); @@ -309,7 +527,7 @@ describe('KernelRouter', () => { // @ts-expect-error - deliberately using an invalid type const invalidItem: RunQueueItem = { type: 'invalid' }; await expect(kernelRouter.deliver(invalidItem)).rejects.toThrow( - 'unknown run queue item type', + 'unsupported or unknown run queue item type', ); }); }); diff --git a/packages/ocap-kernel/src/KernelRouter.ts b/packages/ocap-kernel/src/KernelRouter.ts index 9be7c723b..b0f4d9180 100644 --- a/packages/ocap-kernel/src/KernelRouter.ts +++ b/packages/ocap-kernel/src/KernelRouter.ts @@ -15,6 +15,7 @@ import type { RunQueueItemBringOutYourDead, RunQueueItemNotify, RunQueueItemGCAction, + CrankResults, } from './types.ts'; import { insistVatId, insistMessage } from './types.ts'; import { assert, Fail } from './utils/assert.ts'; @@ -74,27 +75,25 @@ export class KernelRouter { * is also forwarded to all of the promise's registered subscribers. * * @param item - The message/notification to deliver. + * @returns The crank outcome. */ - async deliver(item: RunQueueItem): Promise { + async deliver(item: RunQueueItem): Promise { switch (item.type) { case 'send': - await this.#deliverSend(item); - break; + return await this.#deliverSend(item); case 'notify': - await this.#deliverNotify(item); - break; + return await this.#deliverNotify(item); case 'dropExports': case 'retireExports': case 'retireImports': - await this.#deliverGCAction(item); - break; + return await this.#deliverGCAction(item); case 'bringOutYourDead': - await this.#deliverBringOutYourDead(item); - break; + return await this.#deliverBringOutYourDead(item); default: // @ts-expect-error Runtime does not respect "never". Fail`unsupported or unknown run queue item type ${item.type}`; } + return undefined; } /** @@ -167,9 +166,13 @@ export class KernelRouter { * Deliver a 'send' run queue item. * * @param item - The send item to deliver. + * @returns The crank outcome. */ - async #deliverSend(item: RunQueueItemSend): Promise { + async #deliverSend( + item: RunQueueItemSend, + ): Promise { const route = this.#routeMessage(item); + let crankResults: CrankResults | undefined; // Message went splat if (!route) { @@ -186,7 +189,7 @@ export class KernelRouter { console.log( `@@@@ message went splat ${item.target}<-${JSON.stringify(item.message)}`, ); - return; + return crankResults; } const { vatId, target } = route; @@ -216,7 +219,7 @@ export class KernelRouter { vatId, message, ); - await vat.deliverMessage(vatTarget, vatMessage); + crankResults = await vat.deliverMessage(vatTarget, vatMessage); this.#kernelStore.decrementRefCount(target, 'deliver|send|target'); for (const slot of message.methargs.slots) { this.#kernelStore.decrementRefCount(slot, 'deliver|send|slot'); @@ -230,14 +233,17 @@ export class KernelRouter { console.log( `@@@@ done ${vatId} send ${target}<-${JSON.stringify(message)}`, ); + + return crankResults; } /** * Deliver a 'notify' run queue item. * * @param item - The notify item to deliver. + * @returns The crank outcome. */ - async #deliverNotify(item: RunQueueItemNotify): Promise { + async #deliverNotify(item: RunQueueItemNotify): Promise { const { vatId, kpid } = item; insistVatId(vatId); const { context, isPromise } = parseRef(kpid); @@ -254,12 +260,12 @@ export class KernelRouter { } if (!this.#kernelStore.krefToEref(vatId, kpid)) { // no c-list entry, already done - return; + return { didDelivery: vatId }; } const targets = this.#kernelStore.getKpidsToRetire(kpid, value); if (targets.length === 0) { // no kpids to retire, already done - return; + return { didDelivery: vatId }; } const resolutions: VatOneResolution[] = []; for (const toResolve of targets) { @@ -281,18 +287,20 @@ export class KernelRouter { } } const vat = this.#getVat(vatId); - await vat.deliverNotify(resolutions); + const crankResults = await vat.deliverNotify(resolutions); // Decrement reference count for processed 'notify' item this.#kernelStore.decrementRefCount(kpid, 'deliver|notify'); console.log(`@@@@ done ${vatId} notify ${vatId} ${kpid}`); + return crankResults; } /** * Deliver a Garbage Collection action run queue item. * * @param item - The dropExports | retireExports | retireImports item to deliver. + * @returns The crank outcome. */ - async #deliverGCAction(item: RunQueueItemGCAction): Promise { + async #deliverGCAction(item: RunQueueItemGCAction): Promise { const { type, vatId, krefs } = item; console.log(`@@@@ deliver ${vatId} ${type}`, krefs); const vat = this.#getVat(vatId); @@ -302,22 +310,25 @@ export class KernelRouter { | 'deliverDropExports' | 'deliverRetireExports' | 'deliverRetireImports'; - await vat[method](vrefs); + const crankResults = await vat[method](vrefs); console.log(`@@@@ done ${vatId} ${type}`, krefs); + return crankResults; } /** * Deliver a 'bringOutYourDead' run queue item. * * @param item - The bringOutYourDead item to deliver. + * @returns The crank outcome. */ async #deliverBringOutYourDead( item: RunQueueItemBringOutYourDead, - ): Promise { + ): Promise { const { vatId } = item; console.log(`@@@@ deliver ${vatId} bringOutYourDead`); const vat = this.#getVat(vatId); - await vat.deliverBringOutYourDead(); + const crankResults = await vat.deliverBringOutYourDead(); console.log(`@@@@ done ${vatId} bringOutYourDead`); + return crankResults; } } diff --git a/packages/ocap-kernel/src/VatHandle.test.ts b/packages/ocap-kernel/src/VatHandle.test.ts index d87fed225..342c890aa 100644 --- a/packages/ocap-kernel/src/VatHandle.test.ts +++ b/packages/ocap-kernel/src/VatHandle.test.ts @@ -1,5 +1,4 @@ import type { VatOneResolution } from '@agoric/swingset-liveslots'; -import type { VatCheckpoint } from '@metamask/kernel-store'; import type { JsonRpcMessage } from '@metamask/kernel-utils'; import { isJsonRpcMessage } from '@metamask/kernel-utils'; import type { Logger } from '@metamask/logger'; @@ -12,7 +11,7 @@ import type { MockInstance } from 'vitest'; import type { KernelQueue } from './KernelQueue.ts'; import { makeKernelStore } from './store/index.ts'; import type { KernelStore } from './store/index.ts'; -import type { VRef, Message } from './types.ts'; +import type { VRef, Message, VatDeliveryResult } from './types.ts'; import { VatHandle } from './VatHandle.ts'; import { makeMapKernelDatabase } from '../test/storage.ts'; @@ -123,8 +122,8 @@ describe('VatHandle', () => { it('calls sendVatCommand with the correct method and params', async () => { const { vat } = await makeVat(); sendVatCommandMock.mockReset(); - const mockCheckpoint: VatCheckpoint = [[], []]; - sendVatCommandMock.mockResolvedValueOnce(mockCheckpoint); + const mockResult: VatDeliveryResult = [[[], []], null]; + sendVatCommandMock.mockResolvedValueOnce(mockResult); const target = 'kp1' as VRef; const message: Message = { methargs: { body: '["arg1","arg2"]', slots: [] }, @@ -143,8 +142,8 @@ describe('VatHandle', () => { it('calls sendVatCommand with the correct method and params', async () => { const { vat } = await makeVat(); sendVatCommandMock.mockReset(); - const mockCheckpoint: VatCheckpoint = [[], []]; - sendVatCommandMock.mockResolvedValueOnce(mockCheckpoint); + const mockResult: VatDeliveryResult = [[[], []], null]; + sendVatCommandMock.mockResolvedValueOnce(mockResult); const resolutions: VatOneResolution[] = [ ['vp123', false, { body: '"resolved value"', slots: [] }], ]; @@ -161,8 +160,8 @@ describe('VatHandle', () => { it('calls sendVatCommand with the correct method and params', async () => { const { vat } = await makeVat(); sendVatCommandMock.mockReset(); - const mockCheckpoint: VatCheckpoint = [[], []]; - sendVatCommandMock.mockResolvedValueOnce(mockCheckpoint); + const mockResult: VatDeliveryResult = [[[], []], null]; + sendVatCommandMock.mockResolvedValueOnce(mockResult); const vrefs: VRef[] = ['kp123', 'kp456']; await vat.deliverDropExports(vrefs); expect(sendVatCommandMock).toHaveBeenCalledTimes(1); @@ -177,8 +176,8 @@ describe('VatHandle', () => { it('calls sendVatCommand with the correct method and params', async () => { const { vat } = await makeVat(); sendVatCommandMock.mockReset(); - const mockCheckpoint: VatCheckpoint = [[], []]; - sendVatCommandMock.mockResolvedValueOnce(mockCheckpoint); + const mockResult: VatDeliveryResult = [[[], []], null]; + sendVatCommandMock.mockResolvedValueOnce(mockResult); const vrefs: VRef[] = ['kp123', 'kp456']; await vat.deliverRetireExports(vrefs); expect(sendVatCommandMock).toHaveBeenCalledTimes(1); @@ -193,8 +192,8 @@ describe('VatHandle', () => { it('calls sendVatCommand with the correct method and params', async () => { const { vat } = await makeVat(); sendVatCommandMock.mockReset(); - const mockCheckpoint: VatCheckpoint = [[], []]; - sendVatCommandMock.mockResolvedValueOnce(mockCheckpoint); + const mockResult: VatDeliveryResult = [[[], []], null]; + sendVatCommandMock.mockResolvedValueOnce(mockResult); const vrefs: VRef[] = ['kp123', 'kp456']; await vat.deliverRetireImports(vrefs); expect(sendVatCommandMock).toHaveBeenCalledTimes(1); @@ -209,8 +208,8 @@ describe('VatHandle', () => { it('calls sendVatCommand with the correct method and params', async () => { const { vat } = await makeVat(); sendVatCommandMock.mockReset(); - const mockCheckpoint: VatCheckpoint = [[], []]; - sendVatCommandMock.mockResolvedValueOnce(mockCheckpoint); + const mockResult: VatDeliveryResult = [[[], []], null]; + sendVatCommandMock.mockResolvedValueOnce(mockResult); await vat.deliverBringOutYourDead(); expect(sendVatCommandMock).toHaveBeenCalledTimes(1); expect(sendVatCommandMock).toHaveBeenCalledWith({ diff --git a/packages/ocap-kernel/src/VatHandle.ts b/packages/ocap-kernel/src/VatHandle.ts index 641133d79..b4bd89228 100644 --- a/packages/ocap-kernel/src/VatHandle.ts +++ b/packages/ocap-kernel/src/VatHandle.ts @@ -8,7 +8,7 @@ import type { ExtractParams, ExtractResult, } from '@metamask/kernel-rpc-methods'; -import type { VatStore, VatCheckpoint } from '@metamask/kernel-store'; +import type { VatStore } from '@metamask/kernel-store'; import type { JsonRpcMessage } from '@metamask/kernel-utils'; import { Logger } from '@metamask/logger'; import { serializeError } from '@metamask/rpc-errors'; @@ -18,9 +18,16 @@ import { isJsonRpcRequest, isJsonRpcResponse } from '@metamask/utils'; import type { KernelQueue } from './KernelQueue.ts'; import { vatMethodSpecs, vatSyscallHandlers } from './rpc/index.ts'; import type { PingVatResult, VatMethod } from './rpc/index.ts'; -import { kser } from './services/kernel-marshal.ts'; +import { kser, makeError } from './services/kernel-marshal.ts'; import type { KernelStore } from './store/index.ts'; -import type { Message, VatId, VatConfig, VRef } from './types.ts'; +import type { + Message, + VatId, + VatConfig, + VRef, + CrankResults, + VatDeliveryResult, +} from './types.ts'; import { VatSyscall } from './VatSyscall.ts'; type VatConstructorProps = { @@ -104,8 +111,7 @@ export class VatHandle { ); this.#rpcService = new RpcService(vatSyscallHandlers, { handleSyscall: async (params) => { - await this.#vatSyscall.handleSyscall(params as VatSyscallObject); - return ['ok', null]; // XXX TODO: Return actual results from syscalls + return this.#vatSyscall.handleSyscall(params as VatSyscallObject); }, }); } @@ -202,70 +208,83 @@ export class VatHandle { * * @param target - The VRef of the object to which the message is addressed. * @param message - The message to deliver. + * @returns The crank results. */ - async deliverMessage(target: VRef, message: Message): Promise { + async deliverMessage(target: VRef, message: Message): Promise { await this.sendVatCommand({ method: 'deliver', params: ['message', target, message], }); + return this.#getDeliveryCrankResults(); } /** * Make a 'notify' delivery to the vat. * * @param resolutions - One or more promise resolutions to deliver. + * @returns The crank results. */ - async deliverNotify(resolutions: VatOneResolution[]): Promise { + async deliverNotify(resolutions: VatOneResolution[]): Promise { await this.sendVatCommand({ method: 'deliver', params: ['notify', resolutions], }); + return this.#getDeliveryCrankResults(); } /** * Make a 'dropExports' delivery to the vat. * * @param vrefs - The VRefs of the exports to be dropped. + * @returns The crank results. */ - async deliverDropExports(vrefs: VRef[]): Promise { + async deliverDropExports(vrefs: VRef[]): Promise { await this.sendVatCommand({ method: 'deliver', params: ['dropExports', vrefs], }); + return this.#getDeliveryCrankResults(); } /** * Make a 'retireExports' delivery to the vat. * * @param vrefs - The VRefs of the exports to be retired. + * @returns The crank results. */ - async deliverRetireExports(vrefs: VRef[]): Promise { + async deliverRetireExports(vrefs: VRef[]): Promise { await this.sendVatCommand({ method: 'deliver', params: ['retireExports', vrefs], }); + return this.#getDeliveryCrankResults(); } /** * Make a 'retireImports' delivery to the vat. * * @param vrefs - The VRefs of the imports to be retired. + * @returns The crank results. */ - async deliverRetireImports(vrefs: VRef[]): Promise { + async deliverRetireImports(vrefs: VRef[]): Promise { await this.sendVatCommand({ method: 'deliver', params: ['retireImports', vrefs], }); + return this.#getDeliveryCrankResults(); } /** * Make a 'bringOutYourDead' delivery to the vat. + * + * @returns The crank results. */ - async deliverBringOutYourDead(): Promise { + async deliverBringOutYourDead(): Promise { await this.sendVatCommand({ method: 'deliver', params: ['bringOutYourDead'], }); + return this.#getDeliveryCrankResults(); } /** @@ -277,15 +296,14 @@ export class VatHandle { */ async terminate(terminating: boolean, error?: Error): Promise { await this.#vatStream.end(error); - + const terminationError = error ?? new VatDeletedError(this.vatId); if (terminating) { // Reject promises exported to other vats for which this vat is the decider - const failure = kser(new VatDeletedError(this.vatId)); + const failure = kser(terminationError); for (const kpid of this.#kernelStore.getPromisesByDecider(this.vatId)) { this.#kernelQueue.resolvePromises(this.vatId, [[kpid, true, failure]]); } - - this.#rpcClient.rejectAll(error ?? new VatDeletedError(this.vatId)); + this.#rpcClient.rejectAll(terminationError); this.#kernelStore.deleteVat(this.vatId); } } @@ -306,11 +324,56 @@ export class VatHandle { params: ExtractParams; }): Promise> { const result = await this.#rpcClient.call(method, params); - if (method === 'deliver' || method === 'initVat') { - // TypeScript fails to narrow the result type on its own - const [sets, deletes] = result as VatCheckpoint; - this.#vatStore.updateKVData(sets, deletes); + if (method === 'initVat' || method === 'deliver') { + const [[sets, deletes], deliveryError] = result as VatDeliveryResult; + this.#vatSyscall.deliveryError = deliveryError ?? undefined; + const noErrors = !deliveryError && !this.#vatSyscall.illegalSyscall; + // On errors, we neither update this vat's KV data nor rollback previous changes. + // This is safe because vats are always terminated when errors occur + // and they have their own databases, which are deleted when the vat is terminated. + // The main kernel database will be rolled back. + if (noErrors) { + this.#vatStore.updateKVData(sets, deletes); + } } return result; } + + /** + * Get the crank outcome for a given checkpoint result. + * + * @returns The crank outcome. + */ + async #getDeliveryCrankResults(): Promise { + await this.#vatSyscall.waitForSyscallsToComplete(); + + const results: CrankResults = { + didDelivery: this.vatId, + }; + + // These conditionals express a priority order: the consequences of an + // illegal syscall take precedence over a vat requesting termination, etc. + if (this.#vatSyscall.illegalSyscall) { + results.abort = true; + const { info } = this.#vatSyscall.illegalSyscall; + // TODO: For now, vat errors both rewind changes and terminate the vat. + // Some day, they might rewind changes and retry the syscall. + // We should terminate the vat only after a certain # of failed retries. + results.terminate = { vatId: this.vatId, reject: true, info }; + } else if (this.#vatSyscall.deliveryError) { + results.abort = true; + const info = makeError(this.#vatSyscall.deliveryError); + results.terminate = { vatId: this.vatId, reject: true, info }; + } else if (this.#vatSyscall.vatRequestedTermination) { + if (this.#vatSyscall.vatRequestedTermination.reject) { + results.abort = true; // vatPowers.exitWithFailure wants rewind + } + results.terminate = { + vatId: this.vatId, + ...this.#vatSyscall.vatRequestedTermination, + }; + } + + return harden(results); + } } diff --git a/packages/ocap-kernel/src/VatSupervisor.ts b/packages/ocap-kernel/src/VatSupervisor.ts index 59a65cc29..440890f6d 100644 --- a/packages/ocap-kernel/src/VatSupervisor.ts +++ b/packages/ocap-kernel/src/VatSupervisor.ts @@ -9,7 +9,7 @@ import { makeMarshal } from '@endo/marshal'; import type { CapData } from '@endo/marshal'; import { StreamReadError } from '@metamask/kernel-errors'; import { RpcClient, RpcService } from '@metamask/kernel-rpc-methods'; -import type { VatKVStore, VatCheckpoint } from '@metamask/kernel-store'; +import type { VatKVStore } from '@metamask/kernel-store'; import { waitUntilQuiescent } from '@metamask/kernel-utils'; import type { JsonRpcMessage } from '@metamask/kernel-utils'; import type { Logger } from '@metamask/logger'; @@ -18,13 +18,12 @@ import type { DuplexStream } from '@metamask/streams'; import { isJsonRpcRequest, isJsonRpcResponse } from '@metamask/utils'; import { vatSyscallMethodSpecs, vatHandlers } from './rpc/index.ts'; -import type { InitVat } from './rpc/vat/initVat.ts'; import { makeGCAndFinalize } from './services/gc-finalize.ts'; import { makeDummyMeterControl } from './services/meter-control.ts'; import { makeSupervisorSyscall } from './services/syscall.ts'; import type { DispatchFn, MakeLiveSlotsFn, GCTools } from './services/types.ts'; import { makeVatKVStore } from './store/vat-kv-store.ts'; -import type { VatId } from './types.ts'; +import type { VatConfig, VatDeliveryResult, VatId } from './types.ts'; import { isVatConfig, coerceVatSyscallObject } from './types.ts'; const makeLiveSlots: MakeLiveSlotsFn = localMakeLiveSlots; @@ -179,26 +178,43 @@ export class VatSupervisor { */ executeSyscall(vso: VatSyscallObject): VatSyscallResult { this.#syscallsInFlight.push( - // XXX TODO: These all get rejected, so we have to catch them. See #deliver. + // IMPORTANT: Syscall architecture design explanation: + // - Vats operate on an "optimistic execution" model - they send syscalls and continue execution + // without waiting for responses, assuming success. + // - The Kernel processes syscalls asynchronously and failures are catched in VatHandle. this.#rpcClient .call('syscall', coerceVatSyscallObject(vso)) + // We catch these rejections here to prevent unhandled promise rejections, + // as they're an expected part of the architecture, not errors .catch(() => undefined), ); return ['ok', null]; } - async #deliver(params: VatDeliveryObject): Promise { + async #deliver(params: VatDeliveryObject): Promise { if (!this.#dispatch) { throw new Error(`cannot deliver before vat is loaded`); } - await this.#dispatch(harden(params)); - // XXX TODO: Actually handle the syscall results - this.#syscallsInFlight.length = 0; - this.#rpcClient.rejectAll(new Error('end of crank')); + let deliveryError: string | null = null; + + try { + await this.#dispatch(harden(params)); + } catch (error) { + // Handle delivery errors + deliveryError = error instanceof Error ? error.message : String(error); + this.#logger.error(`Delivery error in vat ${this.id}:`, deliveryError); + } finally { + // Clean up at the end of a crank + this.#syscallsInFlight.length = 0; + // Reject all pending RPC requests to maintain the optimistic execution model + // This prevents late responses from affecting the vat in unexpected ways + // between cranks. + this.#rpcClient.rejectAll(new Error('end of crank')); + } // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - return this.#vatKVStore!.checkpoint(); + return [this.#vatKVStore!.checkpoint(), deliveryError]; } /** @@ -210,7 +226,10 @@ export class VatSupervisor { * * @returns a promise for a checkpoint of the new vat. */ - readonly #initVat: InitVat = async (vatConfig, state) => { + async #initVat( + vatConfig: VatConfig, + state: Map, + ): Promise { if (this.#loaded) { throw Error( 'VatSupervisor received initVat after user code already loaded', @@ -275,8 +294,7 @@ export class VatSupervisor { this.#dispatch = liveslots.dispatch; const serParam = marshal.toCapData(harden(parameters)) as CapData; - await this.#dispatch(harden(['startVat', serParam])); - return this.#vatKVStore.checkpoint(); - }; + return await this.#deliver(harden(['startVat', serParam])); + } } diff --git a/packages/ocap-kernel/src/VatSyscall.test.ts b/packages/ocap-kernel/src/VatSyscall.test.ts index 652dc3a5d..81e93a723 100644 --- a/packages/ocap-kernel/src/VatSyscall.test.ts +++ b/packages/ocap-kernel/src/VatSyscall.test.ts @@ -3,6 +3,7 @@ import type { VatOneResolution, VatSyscallObject, } from '@agoric/swingset-liveslots'; +import * as kernelUtils from '@metamask/kernel-utils'; import type { Logger } from '@metamask/logger'; import type { MockInstance } from 'vitest'; import { describe, it, expect, vi, beforeEach } from 'vitest'; @@ -30,8 +31,13 @@ describe('VatSyscall', () => { clearReachableFlag: vi.fn(), getReachableFlag: vi.fn(), forgetKref: vi.fn(), + getVatConfig: vi.fn(() => ({})), + isVatActive: vi.fn(() => true), } as unknown as KernelStore; - logger = { debug: vi.fn() } as unknown as Logger; + logger = { + debug: vi.fn(), + error: vi.fn(), + } as unknown as Logger; vatSys = new VatSyscall({ vatId: 'v1', kernelQueue, kernelStore, logger }); }); @@ -39,14 +45,14 @@ describe('VatSyscall', () => { const target = 'o+1'; const message = {} as unknown as Message; const vso = ['send', target, message] as unknown as VatSyscallObject; - await vatSys.handleSyscall(vso); + vatSys.handleSyscall(vso); expect(kernelQueue.enqueueSend).toHaveBeenCalledWith(target, message); }); it('calls resolvePromises for resolve syscall', async () => { const resolution = ['kp1', false, {}] as unknown as VatOneResolution; const vso = ['resolve', [resolution]] as unknown as VatSyscallObject; - await vatSys.handleSyscall(vso); + vatSys.handleSyscall(vso); expect(kernelQueue.resolvePromises).toHaveBeenCalledWith('v1', [ resolution, ]); @@ -60,7 +66,7 @@ describe('VatSyscall', () => { state: 'unresolved', }); const vso = ['subscribe', 'kp1'] as unknown as VatSyscallObject; - await vatSys.handleSyscall(vso); + vatSys.handleSyscall(vso); expect(kernelStore.addPromiseSubscriber).toHaveBeenCalledWith( 'v1', 'kp1', @@ -74,7 +80,7 @@ describe('VatSyscall', () => { state: 'fulfilled', }); const vso = ['subscribe', 'kp1'] as unknown as VatSyscallObject; - await vatSys.handleSyscall(vso); + vatSys.handleSyscall(vso); expect(kernelQueue.enqueueNotify).toHaveBeenCalledWith('v1', 'kp1'); }); }); @@ -85,7 +91,7 @@ describe('VatSyscall', () => { 'dropImports', ['o-1', 'o-2'], ] as unknown as VatSyscallObject; - await vatSys.handleSyscall(vso); + vatSys.handleSyscall(vso); expect(kernelStore.clearReachableFlag).toHaveBeenCalledWith('v1', 'o-1'); expect(kernelStore.clearReachableFlag).toHaveBeenCalledWith('v1', 'o-2'); }); @@ -93,9 +99,15 @@ describe('VatSyscall', () => { it.each([ ['o+1', 'vat v1 issued invalid syscall dropImports for o+1'], ['p-1', 'vat v1 issued invalid syscall dropImports for p-1'], - ])('throws for invalid ref %s', async (ref, errMsg) => { + ])('returns error for invalid ref %s', async (ref, errMsg) => { + ( + kernelStore.translateSyscallVtoK as unknown as MockInstance + ).mockImplementationOnce(() => { + throw new Error(errMsg); + }); const vso = ['dropImports', [ref]] as unknown as VatSyscallObject; - await expect(vatSys.handleSyscall(vso)).rejects.toThrow(errMsg); + const result = vatSys.handleSyscall(vso); + expect(result).toStrictEqual(['error', errMsg]); }); }); @@ -105,18 +117,25 @@ describe('VatSyscall', () => { kernelStore.getReachableFlag as unknown as MockInstance ).mockReturnValueOnce(false); const vso = ['retireImports', ['o-1']] as unknown as VatSyscallObject; - await vatSys.handleSyscall(vso); + vatSys.handleSyscall(vso); expect(kernelStore.forgetKref).toHaveBeenCalledWith('v1', 'o-1'); }); - it('throws if still reachable', async () => { + it('returns error if still reachable', async () => { ( - kernelStore.getReachableFlag as unknown as MockInstance - ).mockReturnValueOnce(true); + kernelStore.translateSyscallVtoK as unknown as MockInstance + ).mockImplementationOnce(() => { + ( + kernelStore.getReachableFlag as unknown as MockInstance + ).mockReturnValueOnce(true); + throw new Error('syscall.retireImports but o-1 is still reachable'); + }); const vso = ['retireImports', ['o-1']] as unknown as VatSyscallObject; - await expect(vatSys.handleSyscall(vso)).rejects.toThrow( + const result = vatSys.handleSyscall(vso); + expect(result).toStrictEqual([ + 'error', 'syscall.retireImports but o-1 is still reachable', - ); + ]); }); }); @@ -126,36 +145,96 @@ describe('VatSyscall', () => { kernelStore.getReachableFlag as unknown as MockInstance ).mockReturnValueOnce(false); const vso = ['retireExports', ['o+1']] as unknown as VatSyscallObject; - await vatSys.handleSyscall(vso); + vatSys.handleSyscall(vso); expect(kernelStore.forgetKref).toHaveBeenCalledWith('v1', 'o+1'); expect(logger.debug).toHaveBeenCalledWith( 'retireExports: deleted object o+1', ); }); - it('throws for reachable exports', async () => { + it('returns error for reachable exports', async () => { ( - kernelStore.getReachableFlag as unknown as MockInstance - ).mockReturnValueOnce(true); + kernelStore.translateSyscallVtoK as unknown as MockInstance + ).mockImplementationOnce(() => { + ( + kernelStore.getReachableFlag as unknown as MockInstance + ).mockReturnValueOnce(true); + throw new Error('syscall.retireExports but o+1 is still reachable'); + }); const vso = ['retireExports', ['o+1']] as unknown as VatSyscallObject; - await expect(vatSys.handleSyscall(vso)).rejects.toThrow( + const result = vatSys.handleSyscall(vso); + expect(result).toStrictEqual([ + 'error', 'syscall.retireExports but o+1 is still reachable', - ); + ]); }); it('abandons exports without reachability check', async () => { const vso = ['abandonExports', ['o+1']] as unknown as VatSyscallObject; - await vatSys.handleSyscall(vso); + vatSys.handleSyscall(vso); expect(kernelStore.forgetKref).toHaveBeenCalledWith('v1', 'o+1'); expect(logger.debug).toHaveBeenCalledWith( 'abandonExports: deleted object o+1', ); }); - it('throws for invalid abandonExports refs', async () => { + it('returns error for invalid abandonExports refs', async () => { + ( + kernelStore.translateSyscallVtoK as unknown as MockInstance + ).mockImplementationOnce(() => { + throw new Error('vat v1 issued invalid syscall abandonExports for o-1'); + }); const vso = ['abandonExports', ['o-1']] as unknown as VatSyscallObject; - await expect(vatSys.handleSyscall(vso)).rejects.toThrow( + const result = vatSys.handleSyscall(vso); + expect(result).toStrictEqual([ + 'error', 'vat v1 issued invalid syscall abandonExports for o-1', + ]); + }); + }); + + describe('exit syscall', () => { + it('records vat termination request', async () => { + const vso = [ + 'exit', + true, + { message: 'error' }, + ] as unknown as VatSyscallObject; + vatSys.handleSyscall(vso); + expect(vatSys.vatRequestedTermination).toStrictEqual({ + reject: true, + info: { message: 'error' }, + }); + }); + }); + + describe('error handling', () => { + it('handles vat not found error', async () => { + (kernelStore.isVatActive as unknown as MockInstance).mockReturnValueOnce( + false, + ); + const vso = ['send', 'o+1', {}] as unknown as VatSyscallObject; + const result = vatSys.handleSyscall(vso); + + expect(result).toStrictEqual(['error', 'vat not found']); + expect(vatSys.illegalSyscall).toBeDefined(); + }); + + it('handles general syscall errors', async () => { + const error = new Error('test error'); + ( + kernelStore.translateSyscallVtoK as unknown as MockInstance + ).mockImplementationOnce(() => { + throw error; + }); + + const vso = ['send', 'o+1', {}] as unknown as VatSyscallObject; + const result = vatSys.handleSyscall(vso); + + expect(result).toStrictEqual(['error', 'test error']); + expect(logger.error).toHaveBeenCalledWith( + 'Fatal syscall error in vat v1', + error, ); }); }); @@ -163,15 +242,43 @@ describe('VatSyscall', () => { describe('invalid or unknown syscalls', () => { it.each([ ['vatstoreGet', 'invalid syscall vatstoreGet'], + ['vatstoreGetNextKey', 'invalid syscall vatstoreGetNextKey'], + ['vatstoreSet', 'invalid syscall vatstoreSet'], + ['vatstoreDelete', 'invalid syscall vatstoreDelete'], + ['callNow', 'invalid syscall callNow'], ['unknownOp', 'unknown syscall unknownOp'], ])('%s should warn', async (op, message) => { const spy = vi.spyOn(console, 'warn').mockImplementation(() => { // do nothing }); const vso = [op, []] as unknown as VatSyscallObject; - await vatSys.handleSyscall(vso); + vatSys.handleSyscall(vso); expect(spy).toHaveBeenCalledWith(expect.stringContaining(message), vso); spy.mockRestore(); }); }); + + describe('waitForSyscallsToComplete', () => { + it('resolves immediately if pendingSyscalls is zero', async () => { + vatSys.pendingSyscalls = 0; + const delaySpy = vi.spyOn(kernelUtils, 'delay'); + await vatSys.waitForSyscallsToComplete(); + expect(delaySpy).not.toHaveBeenCalled(); + delaySpy.mockRestore(); + }); + + it('waits and resolves when pendingSyscalls becomes zero', async () => { + vatSys.pendingSyscalls = 2; + const delaySpy = vi + .spyOn(kernelUtils, 'delay') + .mockImplementation(async () => { + vatSys.pendingSyscalls -= 1; + return Promise.resolve(); + }); + await vatSys.waitForSyscallsToComplete(); + expect(delaySpy).toHaveBeenCalledTimes(2); + expect(vatSys.pendingSyscalls).toBe(0); + delaySpy.mockRestore(); + }); + }); }); diff --git a/packages/ocap-kernel/src/VatSyscall.ts b/packages/ocap-kernel/src/VatSyscall.ts index c8a4bf407..10954158a 100644 --- a/packages/ocap-kernel/src/VatSyscall.ts +++ b/packages/ocap-kernel/src/VatSyscall.ts @@ -1,10 +1,14 @@ import type { + SwingSetCapData, VatOneResolution, VatSyscallObject, + VatSyscallResult, } from '@agoric/swingset-liveslots'; +import { delay } from '@metamask/kernel-utils'; import { Logger } from '@metamask/logger'; import type { KernelQueue } from './KernelQueue.ts'; +import { makeError } from './services/kernel-marshal.ts'; import type { KernelStore } from './store/index.ts'; import { parseRef } from './store/utils/parse-ref.ts'; import { coerceMessage } from './types.ts'; @@ -36,6 +40,20 @@ export class VatSyscall { /** Logger for outputting messages (such as errors) to the console */ readonly #logger: Logger; + /** The illegal syscall that was received */ + illegalSyscall: { vatId: VatId; info: SwingSetCapData } | undefined; + + /** The error when delivery failed */ + deliveryError: string | undefined; + + /** The termination request that was received from the vat with syscall.exit() */ + vatRequestedTermination: + | { reject: boolean; info: SwingSetCapData } + | undefined; + + /** The pending syscalls that were received from the vat */ + pendingSyscalls: number = 0; + /** * Construct a new VatSyscall instance. * @@ -159,84 +177,135 @@ export class VatSyscall { * Handle a syscall from the vat. * * @param vso - The syscall that was received. + * @returns The result of the syscall. */ - async handleSyscall(vso: VatSyscallObject): Promise { - const kso: VatSyscallObject = this.#kernelStore.translateSyscallVtoK( - this.vatId, - vso, - ); - const [op] = kso; - const { vatId } = this; - const { log } = console; - switch (op) { - case 'send': { - // [KRef, Message]; - const [, target, message] = kso; - log(`@@@@ ${vatId} syscall send ${target}<-${JSON.stringify(message)}`); - this.#handleSyscallSend(target, coerceMessage(message)); - break; - } - case 'subscribe': { - // [KRef]; - const [, promise] = kso; - log(`@@@@ ${vatId} syscall subscribe ${promise}`); - this.#handleSyscallSubscribe(promise); - break; - } - case 'resolve': { - // [VatOneResolution[]]; - const [, resolutions] = kso; - log(`@@@@ ${vatId} syscall resolve ${JSON.stringify(resolutions)}`); - this.#handleSyscallResolve(resolutions as VatOneResolution[]); - break; - } - case 'exit': { - // [boolean, SwingSetCapData]; - const [, fail, info] = kso; - log(`@@@@ ${vatId} syscall exit fail=${fail} ${JSON.stringify(info)}`); - break; - } - case 'dropImports': { - // [KRef[]]; - const [, refs] = kso; - log(`@@@@ ${vatId} syscall dropImports ${JSON.stringify(refs)}`); - this.#handleSyscallDropImports(refs); - break; - } - case 'retireImports': { - // [KRef[]]; - const [, refs] = kso; - log(`@@@@ ${vatId} syscall retireImports ${JSON.stringify(refs)}`); - this.#handleSyscallRetireImports(refs); - break; - } - case 'retireExports': { - // [KRef[]]; - const [, refs] = kso; - log(`@@@@ ${vatId} syscall retireExports ${JSON.stringify(refs)}`); - this.#handleSyscallExportCleanup(refs, true); - break; - } - case 'abandonExports': { - // [KRef[]]; - const [, refs] = kso; - log(`@@@@ ${vatId} syscall abandonExports ${JSON.stringify(refs)}`); - this.#handleSyscallExportCleanup(refs, false); - break; + handleSyscall(vso: VatSyscallObject): VatSyscallResult { + try { + this.illegalSyscall = undefined; + this.vatRequestedTermination = undefined; + this.pendingSyscalls += 1; + + // This is a safety check - this case should never happen + if (!this.#kernelStore.isVatActive(this.vatId)) { + this.#recordVatFatalSyscall('vat not found'); + return harden(['error', 'vat not found']); } - case 'callNow': - case 'vatstoreGet': - case 'vatstoreGetNextKey': - case 'vatstoreSet': - case 'vatstoreDelete': { - console.warn(`vat ${vatId} issued invalid syscall ${op} `, vso); - break; + + const kso: VatSyscallObject = this.#kernelStore.translateSyscallVtoK( + this.vatId, + vso, + ); + const [op] = kso; + const { vatId } = this; + const { log } = console; + switch (op) { + case 'send': { + // [KRef, Message]; + const [, target, message] = kso; + log( + `@@@@ ${vatId} syscall send ${target}<-${JSON.stringify(message)}`, + ); + this.#handleSyscallSend(target, coerceMessage(message)); + break; + } + case 'subscribe': { + // [KRef]; + const [, promise] = kso; + log(`@@@@ ${vatId} syscall subscribe ${promise}`); + this.#handleSyscallSubscribe(promise); + break; + } + case 'resolve': { + // [VatOneResolution[]]; + const [, resolutions] = kso; + log(`@@@@ ${vatId} syscall resolve ${JSON.stringify(resolutions)}`); + this.#handleSyscallResolve(resolutions as VatOneResolution[]); + break; + } + case 'exit': { + // [boolean, SwingSetCapData]; + const [, isFailure, info] = kso; + log( + `@@@@ ${vatId} syscall exit fail=${isFailure} ${JSON.stringify(info)}`, + ); + this.vatRequestedTermination = { reject: isFailure, info }; + break; + } + case 'dropImports': { + // [KRef[]]; + const [, refs] = kso; + log(`@@@@ ${vatId} syscall dropImports ${JSON.stringify(refs)}`); + this.#handleSyscallDropImports(refs); + break; + } + case 'retireImports': { + // [KRef[]]; + const [, refs] = kso; + log(`@@@@ ${vatId} syscall retireImports ${JSON.stringify(refs)}`); + this.#handleSyscallRetireImports(refs); + break; + } + case 'retireExports': { + // [KRef[]]; + const [, refs] = kso; + log(`@@@@ ${vatId} syscall retireExports ${JSON.stringify(refs)}`); + this.#handleSyscallExportCleanup(refs, true); + break; + } + case 'abandonExports': { + // [KRef[]]; + const [, refs] = kso; + log(`@@@@ ${vatId} syscall abandonExports ${JSON.stringify(refs)}`); + this.#handleSyscallExportCleanup(refs, false); + break; + } + case 'callNow': + case 'vatstoreGet': + case 'vatstoreGetNextKey': + case 'vatstoreSet': + case 'vatstoreDelete': { + console.warn(`vat ${vatId} issued invalid syscall ${op} `, vso); + break; + } + default: + // Compile-time exhaustiveness check + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + console.warn(`vat ${vatId} issued unknown syscall ${op} `, vso); + break; } - default: - // Compile-time exhaustiveness check - // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - console.warn(`vat ${vatId} issued unknown syscall ${op} `, vso); - break; + return harden(['ok', null]); + } catch (error) { + this.#logger.error(`Fatal syscall error in vat ${this.vatId}`, error); + this.#recordVatFatalSyscall('syscall translation error: prepare to die'); + return harden([ + 'error', + error instanceof Error ? error.message : String(error), + ]); + } finally { + this.pendingSyscalls -= 1; + } + } + + /** + * Log a fatal syscall error and set the illegalSyscall property. + * + * @param error - The error message to log. + */ + #recordVatFatalSyscall(error: string): void { + this.illegalSyscall = { vatId: this.vatId, info: makeError(error) }; + } + + /** + * Wait for all syscalls to complete. + * + * This is useful because syscalls are asynchronous, + * and we need to wait for them to complete before returning the result. + * + * @returns A promise that resolves when all syscalls have completed. + */ + async waitForSyscallsToComplete(): Promise { + while (this.pendingSyscalls > 0) { + await delay(10); } } } diff --git a/packages/ocap-kernel/src/rpc/vat/deliver.ts b/packages/ocap-kernel/src/rpc/vat/deliver.ts index 1893430ff..d4fcfe83e 100644 --- a/packages/ocap-kernel/src/rpc/vat/deliver.ts +++ b/packages/ocap-kernel/src/rpc/vat/deliver.ts @@ -1,5 +1,4 @@ import type { Handler, MethodSpec } from '@metamask/kernel-rpc-methods'; -import type { VatCheckpoint } from '@metamask/kernel-store'; import { tuple, literal, @@ -11,12 +10,13 @@ import { import type { Infer } from '@metamask/superstruct'; import { UnsafeJsonStruct } from '@metamask/utils'; -import { VatCheckpointStruct } from './shared.ts'; +import { VatDeliveryResultStruct } from './shared.ts'; import { CapDataStruct, MessageStruct, VatOneResolutionStruct, } from '../../types.ts'; +import type { VatDeliveryResult } from '../../types.ts'; const MessageDeliveryStruct = tuple([ literal('message'), @@ -72,18 +72,18 @@ type VatDeliveryParams = Infer; export type DeliverSpec = MethodSpec< 'deliver', VatDeliveryParams, - Promise + Promise >; export const deliverSpec: DeliverSpec = { method: 'deliver', params: VatDeliveryParamsStruct, - result: VatCheckpointStruct, + result: VatDeliveryResultStruct, } as const; export type HandleDelivery = ( params: VatDeliveryParams, -) => Promise; +) => Promise; type DeliverHooks = { handleDelivery: HandleDelivery; @@ -92,7 +92,7 @@ type DeliverHooks = { export type DeliverHandler = Handler< 'deliver', VatDeliveryParams, - Promise, + Promise, DeliverHooks >; diff --git a/packages/ocap-kernel/src/rpc/vat/initVat.ts b/packages/ocap-kernel/src/rpc/vat/initVat.ts index 2e76f02fd..e46a124c7 100644 --- a/packages/ocap-kernel/src/rpc/vat/initVat.ts +++ b/packages/ocap-kernel/src/rpc/vat/initVat.ts @@ -1,11 +1,10 @@ import type { MethodSpec, Handler } from '@metamask/kernel-rpc-methods'; -import type { VatCheckpoint } from '@metamask/kernel-store'; import { array, object, string, tuple } from '@metamask/superstruct'; import type { Infer } from '@metamask/superstruct'; -import { VatCheckpointStruct } from './shared.ts'; +import { VatDeliveryResultStruct } from './shared.ts'; import { VatConfigStruct } from '../../types.ts'; -import type { VatConfig } from '../../types.ts'; +import type { VatConfig, VatDeliveryResult } from '../../types.ts'; const paramsStruct = object({ vatConfig: VatConfigStruct, @@ -14,18 +13,22 @@ const paramsStruct = object({ type Params = Infer; -export type InitVatSpec = MethodSpec<'initVat', Params, Promise>; +export type InitVatSpec = MethodSpec< + 'initVat', + Params, + Promise +>; export const initVatSpec: InitVatSpec = { method: 'initVat', params: paramsStruct, - result: VatCheckpointStruct, + result: VatDeliveryResultStruct, }; export type InitVat = ( vatConfig: VatConfig, state: Map, -) => Promise; +) => Promise; type InitVatHooks = { initVat: InitVat; @@ -34,7 +37,7 @@ type InitVatHooks = { export type InitVatHandler = Handler< 'initVat', Params, - Promise, + Promise, InitVatHooks >; diff --git a/packages/ocap-kernel/src/rpc/vat/shared.ts b/packages/ocap-kernel/src/rpc/vat/shared.ts index a0aa21425..0a4b313ea 100644 --- a/packages/ocap-kernel/src/rpc/vat/shared.ts +++ b/packages/ocap-kernel/src/rpc/vat/shared.ts @@ -1,8 +1,15 @@ import type { VatCheckpoint } from '@metamask/kernel-store'; import type { Struct } from '@metamask/superstruct'; -import { tuple, array, string } from '@metamask/superstruct'; +import { tuple, array, string, union, literal } from '@metamask/superstruct'; + +import type { VatDeliveryResult } from '../../types.ts'; export const VatCheckpointStruct: Struct = tuple([ array(tuple([string(), string()])), array(string()), ]); + +export const VatDeliveryResultStruct: Struct = tuple([ + VatCheckpointStruct, + union([string(), literal(null)]), +]); diff --git a/packages/ocap-kernel/src/store/index.test.ts b/packages/ocap-kernel/src/store/index.test.ts index 25a2be270..2454f8846 100644 --- a/packages/ocap-kernel/src/store/index.test.ts +++ b/packages/ocap-kernel/src/store/index.test.ts @@ -48,6 +48,7 @@ describe('kernel store', () => { 'clear', 'clearReachableFlag', 'collectGarbage', + 'createCrankSavepoint', 'decRefCount', 'decrementRefCount', 'deleteCListEntry', @@ -57,6 +58,7 @@ describe('kernel store', () => { 'deleteVat', 'deleteVatConfig', 'dequeueRun', + 'endCrank', 'enqueuePromiseMessage', 'enqueueRun', 'erefToKref', @@ -95,6 +97,7 @@ describe('kernel store', () => { 'initKernelPromise', 'isObjectPinned', 'isRootObject', + 'isVatActive', 'isVatTerminated', 'kernelRefExists', 'krefToEref', @@ -105,15 +108,18 @@ describe('kernel store', () => { 'nextReapAction', 'nextTerminatedVatCleanup', 'pinObject', + 'releaseAllSavepoints', 'reset', 'resolveKernelPromise', 'retireKernelObjects', + 'rollbackCrank', 'runQueueLength', 'scheduleReap', 'setGCActions', 'setObjectRefCount', 'setPromiseDecider', 'setVatConfig', + 'startCrank', 'translateCapDataKtoV', 'translateMessageKtoV', 'translateRefKtoV', diff --git a/packages/ocap-kernel/src/store/index.ts b/packages/ocap-kernel/src/store/index.ts index 5e01fc2f4..cd83ae48b 100644 --- a/packages/ocap-kernel/src/store/index.ts +++ b/packages/ocap-kernel/src/store/index.ts @@ -60,6 +60,7 @@ import type { KernelDatabase, KVStore, VatStore } from '@metamask/kernel-store'; import type { KRef, VatId } from '../types.ts'; import { getBaseMethods } from './methods/base.ts'; import { getCListMethods } from './methods/clist.ts'; +import { getCrankMethods } from './methods/crank.ts'; import { getGCMethods } from './methods/gc.ts'; import { getIdMethods } from './methods/id.ts'; import { getObjectMethods } from './methods/object.ts'; @@ -124,6 +125,8 @@ export function makeKernelStore(kdb: KernelDatabase) { gcActions: provideCachedStoredValue('gcActions', '[]'), reapQueue: provideCachedStoredValue('reapQueue', '[]'), terminatedVats: provideCachedStoredValue('vats.terminated', '[]'), + inCrank: false, + savepoints: [], }; const id = getIdMethods(context); @@ -137,6 +140,7 @@ export function makeKernelStore(kdb: KernelDatabase) { const reachable = getReachableMethods(context); const translators = getTranslators(context); const pinned = getPinMethods(context); + const crank = getCrankMethods(context, kdb); /** * Create a new VatStore for a vat. @@ -173,6 +177,7 @@ export function makeKernelStore(kdb: KernelDatabase) { context.nextPromiseId = provideCachedStoredValue('nextPromiseId', '1'); context.nextVatId = provideCachedStoredValue('nextVatId', '1'); context.nextRemoteId = provideCachedStoredValue('nextRemoteId', '1'); + crank.releaseAllSavepoints(); } /** @@ -194,6 +199,7 @@ export function makeKernelStore(kdb: KernelDatabase) { ...vat, ...translators, ...pinned, + ...crank, makeVatStore, deleteVat, clear, diff --git a/packages/ocap-kernel/src/store/methods/crank.test.ts b/packages/ocap-kernel/src/store/methods/crank.test.ts new file mode 100644 index 000000000..38bd62130 --- /dev/null +++ b/packages/ocap-kernel/src/store/methods/crank.test.ts @@ -0,0 +1,142 @@ +import type { KernelDatabase } from '@metamask/kernel-store'; +import { expect, describe, it, vi, beforeEach } from 'vitest'; + +import { getCrankMethods } from './crank.ts'; +import type { StoreContext } from '../types.ts'; + +describe('crank methods', () => { + let context: StoreContext; + let kdb: KernelDatabase; + let crankMethods: ReturnType; + + beforeEach(() => { + context = { + inCrank: false, + savepoints: [], + } as unknown as StoreContext; + + kdb = { + createSavepoint: vi.fn(), + rollbackSavepoint: vi.fn(), + releaseSavepoint: vi.fn(), + } as unknown as KernelDatabase; + + crankMethods = getCrankMethods(context, kdb); + }); + + describe('startCrank', () => { + it('should set inCrank to true', () => { + crankMethods.startCrank(); + expect(context.inCrank).toBe(true); + }); + + it('should throw when already in a crank', () => { + context.inCrank = true; + expect(() => crankMethods.startCrank()).toThrow( + 'startCrank while already in a crank', + ); + }); + }); + + describe('createCrankSavepoint', () => { + it('should create a savepoint when in a crank', () => { + context.inCrank = true; + crankMethods.createCrankSavepoint('test'); + + expect(context.savepoints).toStrictEqual(['test']); + expect(kdb.createSavepoint).toHaveBeenCalledWith('t0'); + }); + + it('should create multiple savepoints sequentially', () => { + context.inCrank = true; + crankMethods.createCrankSavepoint('first'); + crankMethods.createCrankSavepoint('second'); + + expect(context.savepoints).toStrictEqual(['first', 'second']); + expect(kdb.createSavepoint).toHaveBeenCalledWith('t0'); + expect(kdb.createSavepoint).toHaveBeenCalledWith('t1'); + }); + + it('should throw when not in a crank', () => { + expect(() => crankMethods.createCrankSavepoint('test')).toThrow( + 'createCrankSavepoint outside of crank', + ); + }); + }); + + describe('rollbackCrank', () => { + it('should rollback to specified savepoint', () => { + context.inCrank = true; + context.savepoints = ['first', 'second', 'third']; + + crankMethods.rollbackCrank('second'); + + expect(kdb.rollbackSavepoint).toHaveBeenCalledWith('t1'); + expect(context.savepoints).toStrictEqual(['first']); + }); + + it('should throw when savepoint does not exist', () => { + context.inCrank = true; + context.savepoints = ['first', 'second']; + + expect(() => crankMethods.rollbackCrank('nonexistent')).toThrow( + 'no such savepoint as ""nonexistent""', + ); + }); + + it('should throw when not in a crank', () => { + expect(() => crankMethods.rollbackCrank('test')).toThrow( + 'rollbackCrank outside of crank', + ); + }); + }); + + describe('endCrank', () => { + it('should set inCrank to false', () => { + context.inCrank = true; + crankMethods.endCrank(); + + expect(context.inCrank).toBe(false); + }); + + it('should release savepoints if they exist', () => { + context.inCrank = true; + context.savepoints = ['test']; + + crankMethods.endCrank(); + + expect(kdb.releaseSavepoint).toHaveBeenCalledWith('t0'); + expect(context.savepoints).toStrictEqual([]); + }); + + it('should not call releaseSavepoint if no savepoints exist', () => { + context.inCrank = true; + + crankMethods.endCrank(); + + expect(kdb.releaseSavepoint).not.toHaveBeenCalled(); + }); + + it('should throw when not in a crank', () => { + expect(() => crankMethods.endCrank()).toThrow( + 'endCrank outside of crank', + ); + }); + }); + + describe('releaseAllSavepoints', () => { + it('should release all savepoints', () => { + context.inCrank = true; + context.savepoints = ['test']; + crankMethods.releaseAllSavepoints(); + expect(kdb.releaseSavepoint).toHaveBeenCalledWith('t0'); + expect(context.savepoints).toStrictEqual([]); + }); + + it('should not call releaseSavepoint if no savepoints exist', () => { + context.inCrank = true; + crankMethods.releaseAllSavepoints(); + expect(kdb.releaseSavepoint).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/ocap-kernel/src/store/methods/crank.ts b/packages/ocap-kernel/src/store/methods/crank.ts new file mode 100644 index 000000000..b5b8f900c --- /dev/null +++ b/packages/ocap-kernel/src/store/methods/crank.ts @@ -0,0 +1,81 @@ +import { Fail, q } from '@endo/errors'; +import type { KernelDatabase } from '@metamask/kernel-store'; + +import type { StoreContext } from '../types.ts'; + +/** + * Get the crank methods. + * + * @param ctx - The store context. + * @param kdb - The kernel database. + * @returns The crank methods. + */ +// eslint-disable-next-line @typescript-eslint/explicit-function-return-type +export function getCrankMethods(ctx: StoreContext, kdb: KernelDatabase) { + /** + * Start a crank. + */ + function startCrank(): void { + !ctx.inCrank || Fail`startCrank while already in a crank`; + ctx.inCrank = true; + } + + /** + * Create a savepoint in the crank. + * + * @param name - The savepoint name. + */ + function createCrankSavepoint(name: string): void { + ctx.inCrank || Fail`createCrankSavepoint outside of crank`; + const ordinal = ctx.savepoints.length; + ctx.savepoints.push(name); + kdb.createSavepoint(`t${ordinal}`); + } + + /** + * Rollback a crank. + * + * @param savepoint - The savepoint name. + */ + function rollbackCrank(savepoint: string): void { + ctx.inCrank || Fail`rollbackCrank outside of crank`; + for (const ordinal of ctx.savepoints.keys()) { + if (ctx.savepoints[ordinal] === savepoint) { + kdb.rollbackSavepoint(`t${ordinal}`); + ctx.savepoints.length = ordinal; + return; + } + } + Fail`no such savepoint as "${q(savepoint)}"`; + } + + /** + * End a crank. + */ + function endCrank(): void { + ctx.inCrank || Fail`endCrank outside of crank`; + if (ctx.savepoints.length > 0) { + kdb.releaseSavepoint('t0'); + ctx.savepoints.length = 0; + } + ctx.inCrank = false; + } + + /** + * Release all savepoints. + */ + function releaseAllSavepoints(): void { + if (ctx.savepoints.length > 0) { + kdb.releaseSavepoint('t0'); + ctx.savepoints.length = 0; + } + } + + return { + startCrank, + createCrankSavepoint, + rollbackCrank, + endCrank, + releaseAllSavepoints, + }; +} diff --git a/packages/ocap-kernel/src/store/methods/vat.test.ts b/packages/ocap-kernel/src/store/methods/vat.test.ts index d86c0c35d..4a9ec6906 100644 --- a/packages/ocap-kernel/src/store/methods/vat.test.ts +++ b/packages/ocap-kernel/src/store/methods/vat.test.ts @@ -509,4 +509,30 @@ describe('vat store methods', () => { ); }); }); + + describe('isVatActive', () => { + it('returns true when vat configuration exists', () => { + mockKV.set(`vatConfig.${vatID1}`, JSON.stringify(vatConfig1)); + + const result = vatMethods.isVatActive(vatID1); + + expect(result).toBe(true); + }); + + it('returns false when vat configuration does not exist', () => { + const result = vatMethods.isVatActive(vatID1); + + expect(result).toBe(false); + }); + + it('returns false after vat configuration is deleted', () => { + mockKV.set(`vatConfig.${vatID1}`, JSON.stringify(vatConfig1)); + expect(vatMethods.isVatActive(vatID1)).toBe(true); + + mockKV.delete(`vatConfig.${vatID1}`); + + const result = vatMethods.isVatActive(vatID1); + expect(result).toBe(false); + }); + }); }); diff --git a/packages/ocap-kernel/src/store/methods/vat.ts b/packages/ocap-kernel/src/store/methods/vat.ts index 8850095eb..4d29d970a 100644 --- a/packages/ocap-kernel/src/store/methods/vat.ts +++ b/packages/ocap-kernel/src/store/methods/vat.ts @@ -91,6 +91,16 @@ export function getVatMethods(ctx: StoreContext) { ) as VatConfig; } + /** + * Check if a vat is active. + * + * @param vatID - The ID of the vat to check. + * @returns True if the vat is active, false otherwise. + */ + function isVatActive(vatID: VatId): boolean { + return kv.get(`${VAT_CONFIG_BASE}${vatID}`) !== undefined; + } + /** * Store the configuration for a vat. * @@ -349,5 +359,6 @@ export function getVatMethods(ctx: StoreContext) { cleanupTerminatedVat, nextTerminatedVatCleanup, exportFromVat, + isVatActive, }; } diff --git a/packages/ocap-kernel/src/store/types.ts b/packages/ocap-kernel/src/store/types.ts index a2ce57e7c..cf382bf89 100644 --- a/packages/ocap-kernel/src/store/types.ts +++ b/packages/ocap-kernel/src/store/types.ts @@ -14,6 +14,8 @@ export type StoreContext = { gcActions: StoredValue; reapQueue: StoredValue; terminatedVats: StoredValue; + inCrank: boolean; + savepoints: string[]; }; export type StoredValue = { diff --git a/packages/ocap-kernel/src/types.ts b/packages/ocap-kernel/src/types.ts index c2adef5c0..58a69efa2 100644 --- a/packages/ocap-kernel/src/types.ts +++ b/packages/ocap-kernel/src/types.ts @@ -1,9 +1,11 @@ import type { + SwingSetCapData, Message as SwingsetMessage, VatSyscallObject, VatSyscallSend, } from '@agoric/swingset-liveslots'; import type { CapData } from '@endo/marshal'; +import type { VatCheckpoint } from '@metamask/kernel-store'; import type { JsonRpcMessage } from '@metamask/kernel-utils'; import type { DuplexStream } from '@metamask/streams'; import { @@ -249,10 +251,11 @@ export type VatWorkerService = { * Terminate a worker identified by its vat id. * * @param vatId - The vat id of the worker to terminate. + * @param error - An optional error to terminate the worker with. * @returns A promise that resolves when the worker has terminated * or rejects if that worker does not exist. */ - terminate: (vatId: VatId) => Promise; + terminate: (vatId: VatId, error?: Error) => Promise; /** * Terminate all workers managed by the service. * @@ -377,3 +380,11 @@ export const GCActionStruct = define('GCAction', (value: unknown) => { export const isGCAction = (value: unknown): value is GCAction => is(value, GCActionStruct); + +export type CrankResults = { + didDelivery?: VatId; // the vat on which we made a delivery + abort?: boolean; // changes should be discarded, not committed + terminate?: { vatId: VatId; reject: boolean; info: SwingSetCapData }; +}; + +export type VatDeliveryResult = [VatCheckpoint, string | null]; diff --git a/packages/ocap-kernel/test/storage.ts b/packages/ocap-kernel/test/storage.ts index 9d2fc04db..41f9abd2e 100644 --- a/packages/ocap-kernel/test/storage.ts +++ b/packages/ocap-kernel/test/storage.ts @@ -127,5 +127,14 @@ export function makeMapKernelDatabase(): KernelDatabase { deleteVatStore: (vatID: string) => { vatStores.delete(vatID); }, + createSavepoint: () => { + // noop + }, + rollbackSavepoint: () => { + // noop + }, + releaseSavepoint: () => { + // noop + }, }; } diff --git a/vitest.config.ts b/vitest.config.ts index 68d090729..566c29108 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -110,10 +110,10 @@ export default defineConfig({ lines: 0, }, 'packages/kernel-store/**': { - statements: 92.44, - functions: 91.17, - branches: 84.78, - lines: 92.39, + statements: 97.99, + functions: 100, + branches: 91.25, + lines: 97.98, }, 'packages/kernel-utils/**': { statements: 100, @@ -134,10 +134,10 @@ export default defineConfig({ lines: 73.58, }, 'packages/ocap-kernel/**': { - statements: 91.58, - functions: 94.96, - branches: 81.89, - lines: 91.56, + statements: 91.39, + functions: 95.09, + branches: 81.99, + lines: 91.37, }, 'packages/streams/**': { statements: 100,