diff --git a/packages/kernel-agents/CHANGELOG.md b/packages/kernel-agents/CHANGELOG.md new file mode 100644 index 000000000..0c82cb1ed --- /dev/null +++ b/packages/kernel-agents/CHANGELOG.md @@ -0,0 +1,10 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +[Unreleased]: https://github.com/MetaMask/ocap-kernel/ diff --git a/packages/kernel-agents/README.md b/packages/kernel-agents/README.md new file mode 100644 index 000000000..ec6495ebb --- /dev/null +++ b/packages/kernel-agents/README.md @@ -0,0 +1,41 @@ +# `@ocap/kernel-agents` + +Capability-enabled, language-model-flow-controlled programming. + +## Installation + +`yarn add @ocap/kernel-agents` + +or + +`npm install @ocap/kernel-agents` + +## Contributing + +This package is part of a monorepo. Instructions for contributing can be found in the [monorepo README](https://github.com/MetaMask/ocap-kernel#readme). + +## Running E2E Tests + +The end to end tests assume an [ollama](https://ollama.com/) server is running on `localhost:11343` and has the [DEFAULT_MODEL](./test/constants.ts) already pulled. + +### Pulling an Ollama model (CLI) + +`ollama pull 'llama3.1:latest'` + +### Pulling an Ollama model (curl) + +```sh +curl -X POST http://localhost:11434/api/pull -d '{ + "name": "llama3.1:latest" +}' +``` + +### Test Commands + +To run the test suite, use the `yarn test:e2e` command. Ollama configuration errors will be detected by the [suite tests](./test/e2e/suite.test.ts). + +To observe intermediate steps, including prompts provided to the agent, use the `--no-silent` flag. + +```sh +yarn test:e2e --no-silent +``` diff --git a/packages/kernel-agents/package.json b/packages/kernel-agents/package.json new file mode 100644 index 000000000..bab12f4cf --- /dev/null +++ b/packages/kernel-agents/package.json @@ -0,0 +1,89 @@ +{ + "name": "@ocap/kernel-agents", + "version": "0.0.0", + "private": true, + "description": "Capability-enabled, language-model-flow-controlled programming", + "homepage": "https://github.com/MetaMask/ocap-kernel/tree/main/packages/kernel-agents#readme", + "bugs": { + "url": "https://github.com/MetaMask/ocap-kernel/issues" + }, + "repository": { + "type": "git", + "url": "https://github.com/MetaMask/ocap-kernel.git" + }, + "type": "module", + "exports": { + ".": { + "import": { + "types": "./dist/index.d.mts", + "default": "./dist/index.mjs" + }, + "require": { + "types": "./dist/index.d.cts", + "default": "./dist/index.cjs" + } + }, + "./package.json": "./package.json" + }, + "files": [ + "dist/" + ], + "scripts": { + "build": "ts-bridge --project tsconfig.build.json --no-references --clean", + "build:docs": "typedoc", + "changelog:validate": "../../scripts/validate-changelog.sh @ocap/kernel-agents", + "clean": "rimraf --glob './*.tsbuildinfo' ./.eslintcache ./coverage ./dist", + "lint": "yarn lint:eslint && yarn lint:misc --check && yarn constraints && yarn lint:dependencies", + "lint:dependencies": "depcheck", + "lint:eslint": "eslint . --cache", + "lint:fix": "yarn lint:eslint --fix && yarn lint:misc --write && yarn constraints --fix && yarn lint:dependencies", + "lint:misc": "prettier --no-error-on-unmatched-pattern '**/*.json' '**/*.md' '**/*.html' '!**/CHANGELOG.old.md' '**/*.yml' '!.yarnrc.yml' '!merged-packages/**' --ignore-path ../../.gitignore", + "publish:preview": "yarn npm publish --tag preview", + "test": "vitest run --config vitest.config.ts", + "test:e2e": "vitest run --config vitest.config.e2e.ts", + "test:clean": "yarn test --no-cache --coverage.clean", + "test:dev": "yarn test --mode development", + "test:verbose": "yarn test --reporter verbose", + "test:watch": "vitest --config vitest.config.ts" + }, + "devDependencies": { + "@arethetypeswrong/cli": "^0.17.4", + "@metamask/auto-changelog": "^5.0.1", + "@metamask/eslint-config": "^14.0.0", + "@metamask/eslint-config-nodejs": "^14.0.0", + "@metamask/eslint-config-typescript": "^14.0.0", + "@ocap/repo-tools": "workspace:^", + "@ts-bridge/cli": "^0.6.3", + "@ts-bridge/shims": "^0.1.1", + "@types/node": "^22.13.1", + "@typescript-eslint/eslint-plugin": "^8.29.0", + "@typescript-eslint/parser": "^8.29.0", + "@typescript-eslint/utils": "^8.29.0", + "@vitest/eslint-plugin": "^1.3.4", + "depcheck": "^1.4.7", + "eslint": "^9.23.0", + "eslint-config-prettier": "^10.1.1", + "eslint-import-resolver-typescript": "^4.3.1", + "eslint-plugin-import-x": "^4.10.0", + "eslint-plugin-jsdoc": "^50.6.9", + "eslint-plugin-n": "^17.17.0", + "eslint-plugin-prettier": "^5.2.6", + "eslint-plugin-promise": "^7.2.1", + "prettier": "^3.5.3", + "rimraf": "^6.0.1", + "turbo": "^2.5.6", + "typedoc": "^0.28.1", + "typescript": "~5.8.2", + "typescript-eslint": "^8.29.0", + "vite": "^7.1.2", + "vitest": "^3.2.4" + }, + "engines": { + "node": "^20.6 || >=22" + }, + "dependencies": { + "@metamask/kernel-utils": "workspace:^", + "@metamask/logger": "workspace:^", + "@ocap/kernel-language-model-service": "workspace:^" + } +} diff --git a/packages/kernel-agents/src/agent.test.ts b/packages/kernel-agents/src/agent.test.ts new file mode 100644 index 000000000..4bb8a1389 --- /dev/null +++ b/packages/kernel-agents/src/agent.test.ts @@ -0,0 +1,124 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import type { Logger } from '@metamask/logger'; +import { vi, describe, it, expect } from 'vitest'; + +import { makeAgent } from './agent.ts'; +import { capability } from './capability.ts'; +import { end } from './default-capabilities.ts'; +import { AssistantMessage, CapabilityResultMessage } from './messages.ts'; +import { makeChat } from './prompt.ts'; + +const prompt = 'test prompt'; +const prefix = '{"messageType":"assistant","'; + +vi.mock('./prompt.ts', () => ({ + makeChat: vi.fn(() => ({ + getPromptAndPrefix: vi.fn(() => ({ prompt, prefix })), + pushMessages: vi.fn(), + })), +})); + +describe('makeAgent', () => { + const mockLlm = (...chunks: string[]) => ({ + getInfo: vi.fn(), + load: vi.fn(), + unload: vi.fn(), + sample: vi.fn().mockResolvedValue({ + stream: { + async *[Symbol.asyncIterator]() { + for (const chunk of chunks) { + yield { response: chunk }; + } + }, + }, + abort: vi.fn(), + }), + }); + + it('makes an agent', () => { + const llm = mockLlm(); + const agent = makeAgent({ llm, capabilities: {} }); + expect(agent).toBeDefined(); + expect(agent).toHaveProperty('task'); + }); + + it('endows the "end" capability by default', async () => { + const llm = mockLlm(); + const mockMergeDisjointRecordsSpy = vi.spyOn( + await import('@metamask/kernel-utils'), + 'mergeDisjointRecords', + ); + const capabilities = {}; + makeAgent({ llm, capabilities }); + expect(mockMergeDisjointRecordsSpy).toHaveBeenCalledWith( + { end }, + capabilities, + ); + }); + + describe('task', () => { + it('invokes the LLM', async () => { + const llm = mockLlm(`invoke":[{"name":"end","args":{"final":"x"}}]}`); + const agent = makeAgent({ llm, capabilities: {} }); + const result = await agent.task(''); + expect(result).toBe('x'); + // This is a massive understatement, but we don't want to test the prompt + expect(llm.sample).toHaveBeenCalledWith(prompt); + }); + + it('throws if the LLM did not invoke a capability', async () => { + // LLM finishes valid JSON, but no invoke property + const llm = mockLlm(`content":""}`); + const agent = makeAgent({ llm, capabilities: {} }); + const task = agent.task(''); + await expect(task).rejects.toThrow('No invoke in result'); + }); + + it('throws if invocation budget is exceeded', async () => { + const llm = mockLlm(`invoke":[{"name":"end","args":{"final":"x"}}]}`); + const agent = makeAgent({ llm, capabilities: {} }); + const task = agent.task('', { invocationBudget: 0 }); + await expect(task).rejects.toThrow('Invocation budget exceeded'); + }); + + // XXX This test reflects a poor factorization of the agent. + it('pushes messages to the transcript', async () => { + const llm = mockLlm(`invoke":[{"name":"test","args":{}}]}`); + const pushMessages = vi.fn(); + vi.mocked(makeChat).mockReturnValue({ + getPromptAndPrefix: vi.fn(() => ({ prompt, prefix })), + pushMessages, + }); + const { makeAgent: makeAgent2 } = await import('./agent.ts'); + const agent = makeAgent2({ + llm, + capabilities: { + test: capability(async () => 'test', { + description: 'test', + args: {}, + returns: { type: 'string' }, + }), + }, + }); + const task = agent.task('test', { invocationBudget: 1 }); + await expect(task).rejects.toThrow('Invocation budget exceeded'); + expect(pushMessages).toHaveBeenCalledWith( + expect.any(AssistantMessage), + expect.any(CapabilityResultMessage), + ); + }); + + it('logs to the provided logger', async () => { + const llm = mockLlm(`invoke":[{"name":"end","args":{"final":"x"}}]}`); + const logger = { + info: vi.fn(), + subLogger: vi.fn(() => logger), + } as unknown as Logger; + const agent = makeAgent({ llm, capabilities: {}, logger }); + await agent.task('test', { invocationBudget: 1 }); + expect(logger.info).toHaveBeenCalledWith('query:', 'test'); + expect(logger.subLogger).toHaveBeenCalledWith({ tags: ['t001'] }); + }); + }); +}); diff --git a/packages/kernel-agents/src/agent.ts b/packages/kernel-agents/src/agent.ts new file mode 100644 index 000000000..af2258add --- /dev/null +++ b/packages/kernel-agents/src/agent.ts @@ -0,0 +1,98 @@ +import { makeCounter, mergeDisjointRecords } from '@metamask/kernel-utils'; +import type { Logger } from '@metamask/logger'; +import type { LanguageModel } from '@ocap/kernel-language-model-service'; + +import { invokeCapabilities } from './capability.ts'; +import { end } from './default-capabilities.ts'; +import { AssistantMessage, CapabilityResultMessage } from './messages.ts'; +import type { AssistantMessageJson } from './messages.ts'; +import { gatherStreamingResponse, makeIncrementalParser } from './parser.ts'; +import { makeChat } from './prompt.ts'; +import type { Agent, CapabilityRecord } from './types.ts'; + +/** + * Make a capability-augmented agent + * + * @param args - The arguments to make the agent. + * @param args.llm - The language model to use for the agent + * @param args.capabilities - The agent's capabilities + * @param args.logger - The logger to use for the agent + * @returns A kernel agent + */ +export const makeAgent = ({ + llm, + capabilities, + logger, +}: { + llm: LanguageModel; + capabilities: CapabilityRecord; + logger?: Logger; +}): Agent => { + const agentCapabilities = mergeDisjointRecords( + { end }, + capabilities, + ) as CapabilityRecord; + + const taskCounter = makeCounter(); + + return { + task: async ( + query: string, + { invocationBudget = 10 }: { invocationBudget?: number } = {}, + ) => { + // XXX Tasks could be integrated deeper in the kernel + const taskId = `t${taskCounter().toString().padStart(3, '0')}`; + const taskLogger = logger?.subLogger({ tags: [taskId] }); + taskLogger?.info('query:', query); + + const { getPromptAndPrefix, pushMessages } = makeChat( + agentCapabilities, + query, + ); + + for (let invocation = 0; invocation < invocationBudget; invocation++) { + taskLogger?.info(`begin invocation ${invocation}/${invocationBudget}`); + + const { prompt, prefix } = getPromptAndPrefix(); + const parse = makeIncrementalParser({ + prefix, + ...(taskLogger ? { logger: taskLogger } : {}), + }); + taskLogger?.info('prompt:', prompt); + + const { stream, abort } = await llm.sample(prompt); + let assistantMessage: AssistantMessageJson; + try { + assistantMessage = await gatherStreamingResponse({ + stream, + parse, + }); + } finally { + // Stop the LLM from generating anymore + await abort(); + } + taskLogger?.info('assistantMessage:', assistantMessage); + + // TODO: this should already be validated by the parser + if (!assistantMessage.invoke) { + throw new Error('No invoke in result'); + } + const results = await invokeCapabilities( + assistantMessage.invoke, + agentCapabilities, + ); + logger?.info('results:', results); + const didEnd = results.find((capability) => capability.name === 'end'); + if (didEnd) { + logger?.info('exit invocation with result:', didEnd.result); + return didEnd.result; + } + pushMessages( + new AssistantMessage(assistantMessage), + new CapabilityResultMessage(results), + ); + } + throw new Error('Invocation budget exceeded'); + }, + }; +}; diff --git a/packages/kernel-agents/src/capability.test.ts b/packages/kernel-agents/src/capability.test.ts new file mode 100644 index 000000000..709930d54 --- /dev/null +++ b/packages/kernel-agents/src/capability.test.ts @@ -0,0 +1,38 @@ +import { describe, it, expect } from 'vitest'; + +import { capability, invokeCapabilities } from './capability.ts'; + +describe('capability', () => { + it('creates a capability', () => { + const testCapability = capability(async () => Promise.resolve('test'), { + description: 'a test capability', + args: {}, + }); + expect(testCapability).toStrictEqual({ + func: expect.any(Function), + schema: { description: 'a test capability', args: {} }, + }); + }); +}); + +describe('invokeCapabilities', () => { + it("invokes the assistant's chosen capability", async () => { + const testCapability = capability(async () => Promise.resolve('test'), { + description: 'a test capability', + args: {}, + }); + const result = await invokeCapabilities( + [{ name: 'testCapability', args: {} }], + { testCapability }, + ); + expect(result).toStrictEqual([ + { name: 'testCapability', args: {}, result: 'test' }, + ]); + }); + + it('throws if the capability is not found', async () => { + await expect( + invokeCapabilities([{ name: 'testCapability', args: {} }], {}), + ).rejects.toThrow('Invoked capability testCapability not found'); + }); +}); diff --git a/packages/kernel-agents/src/capability.ts b/packages/kernel-agents/src/capability.ts new file mode 100644 index 000000000..a379daaac --- /dev/null +++ b/packages/kernel-agents/src/capability.ts @@ -0,0 +1,50 @@ +import type { Invocation } from './messages.ts'; +import type { + Capability, + CapabilityRecord, + CapabilitySchema, + CapabilitySpec, + ExtractRecordKeys, +} from './types.ts'; + +export const capability = , Return = null>( + func: Capability, + schema: CapabilitySchema>, +): CapabilitySpec => ({ func, schema }); + +type SchemaEntry = [string, { schema: CapabilitySchema }]; +/** + * Extract only the serializable schemas from the capabilities + * + * @param capabilities - The capabilities to extract the schemas from + * @returns A record mapping capability names to their schemas + */ +export const extractCapabilitySchemas = ( + capabilities: CapabilityRecord, +): Record< + keyof typeof capabilities, + (typeof capabilities)[keyof typeof capabilities]['schema'] +> => + Object.fromEntries( + (Object.entries(capabilities) as unknown as SchemaEntry[]).map( + ([name, { schema }]) => [name, schema], + ), + ); + +export const invokeCapabilities = async ( + invocations: Invocation[], + capabilities: CapabilityRecord, +): Promise<(Invocation & { result: unknown })[]> => + await Promise.all( + invocations.map(async ({ name, args }) => ({ + name, + args, + result: await (async () => { + const toInvoke = capabilities[name]; + if (!toInvoke) { + throw new Error(`Invoked capability ${name} not found`); + } + return await toInvoke.func(args as never); + })(), + })), + ); diff --git a/packages/kernel-agents/src/default-capabilities.ts b/packages/kernel-agents/src/default-capabilities.ts new file mode 100644 index 000000000..17e1b0039 --- /dev/null +++ b/packages/kernel-agents/src/default-capabilities.ts @@ -0,0 +1,12 @@ +import { capability } from './capability.ts'; + +export const end = capability(async ({ final }: { final: string }) => final, { + description: 'Return a final response to the user.', + args: { + final: { + type: 'string', + description: + 'A concise final response that restates the requested information', + }, + }, +}); diff --git a/packages/kernel-agents/src/example-capabilities.test.ts b/packages/kernel-agents/src/example-capabilities.test.ts new file mode 100644 index 000000000..b03870ee0 --- /dev/null +++ b/packages/kernel-agents/src/example-capabilities.test.ts @@ -0,0 +1,25 @@ +import { describe, it, expect } from 'vitest'; + +import { exampleCapabilities } from './example-capabilities.ts'; + +describe('exampleCapabilities', () => { + it('contains the correct capabilities', () => { + expect(exampleCapabilities).toBeDefined(); + expect(Object.keys(exampleCapabilities)).toStrictEqual([ + 'count', + 'add', + 'multiply', + ]); + }); + + it.each([ + ['count', { word: 'abcdefg' }, 7], + ['add', { summands: [1, 2, 3, 4] }, 10], + ['multiply', { factors: [1, 2, 3, 4] }, 24], + ])('%s(%s) = %s', async (name, args, expected) => { + const capability = + exampleCapabilities[name as keyof typeof exampleCapabilities]; + expect(capability).toBeDefined(); + expect(await capability.func(args as never)).toStrictEqual(expected); + }); +}); diff --git a/packages/kernel-agents/src/example-capabilities.ts b/packages/kernel-agents/src/example-capabilities.ts new file mode 100644 index 000000000..9305e24c1 --- /dev/null +++ b/packages/kernel-agents/src/example-capabilities.ts @@ -0,0 +1,47 @@ +import { capability } from './capability.ts'; + +export const count = capability( + async ({ word }: { word: string }) => word.length, + { + description: 'Count the number of characters in an arbitrary string', + args: { + word: { type: 'string', description: 'The string to get the length of.' }, + }, + returns: { + type: 'number', + description: 'The number of characters in the string.', + }, + }, +); + +export const add = capability( + async ({ summands }: { summands: number[] }) => + summands.reduce((acc, summand) => acc + summand, 0), + { + description: 'Add a list of numbers.', + args: { summands: { type: 'array', item: { type: 'number' } } }, + returns: { type: 'number', description: 'The sum of the numbers.' }, + }, +); + +export const multiply = capability( + async ({ factors }: { factors: number[] }) => + factors.reduce((acc, factor) => acc * factor, 1), + { + description: 'Multiply a list of numbers.', + args: { + factors: { + type: 'array', + description: 'The list of numbers to multiply.', + item: { type: 'number' }, + }, + }, + returns: { type: 'number', description: 'The product of the factors.' }, + }, +); + +export const exampleCapabilities = { + count, + add, + multiply, +}; diff --git a/packages/kernel-agents/src/example-transcripts.ts b/packages/kernel-agents/src/example-transcripts.ts new file mode 100644 index 000000000..9099a2e23 --- /dev/null +++ b/packages/kernel-agents/src/example-transcripts.ts @@ -0,0 +1,68 @@ +/** + * Example transcripts for the prompt + */ +import { extractCapabilitySchemas } from './capability.ts'; +import { end as endCapability } from './default-capabilities.ts'; +import { exampleCapabilities } from './example-capabilities.ts'; +import { + CapabilitySpecMessage, + UserMessage, + AssistantMessage, + CapabilityResultMessage, +} from './messages.ts'; +import type { Transcript } from './messages.ts'; + +const { end, count, add, multiply } = extractCapabilitySchemas({ + ...exampleCapabilities, + end: endCapability, +}); + +const simpleSemanticTask: Transcript = [ + new CapabilitySpecMessage({ end, add }), + new UserMessage('What color is a banana?'), + new AssistantMessage({ + think: [ + 'Bananas can be either yellow or green, depending on the variety and ripeness.', + 'Typically, people think of yellow bananas when they think of bananas.', + 'I should give the typical response, but clarify that I am assuming the banana is ripe.', + ], + invoke: [{ name: 'end', args: { final: 'A banana is yellow when ripe.' } }], + }), +]; + +const multiStepCalculation: Transcript = [ + new CapabilitySpecMessage({ end, count, multiply }), + new UserMessage( + 'What is the size of a matrix with rows indexed by the letters of "piano" and columns by the letters of "guitar"?', + ), + new AssistantMessage({ + think: [ + 'I need to find the size of a matrix with rows indexed by the letters of "piano" and columns by the letters of "guitar".', + 'The answer will be the product of the length of the word "piano" and the length of the word "guitar".', + 'To prove my answer, I will count the lengths of the words using the "count" capability, then multiply the results using the "multiply" capability.', + ], + invoke: [ + { name: 'count', args: { word: 'piano' } }, + { name: 'count', args: { word: 'guitar' } }, + ], + }), + new CapabilityResultMessage([ + { name: 'count', args: { word: 'piano' }, result: 5 }, + { name: 'count', args: { word: 'guitar' }, result: 6 }, + ]), + new AssistantMessage({ + think: ['Now I can multiply the results to get the answer.'], + invoke: [{ name: 'multiply', args: { factors: [5, 6] } }], + }), + new CapabilityResultMessage([ + { name: 'multiply', args: { factors: [5, 6] }, result: 30 }, + ]), + new AssistantMessage({ + think: ['Now I can return the answer.'], + invoke: [ + { name: 'end', args: { final: 'Such a matrix would have 30 elements.' } }, + ], + }), +]; + +export const exampleTranscripts = [simpleSemanticTask, multiStepCalculation]; diff --git a/packages/kernel-agents/src/index.test.ts b/packages/kernel-agents/src/index.test.ts new file mode 100644 index 000000000..3f79c6c12 --- /dev/null +++ b/packages/kernel-agents/src/index.test.ts @@ -0,0 +1,12 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; +import { describe, it, expect } from 'vitest'; + +import * as indexModule from './index.ts'; + +describe('index', () => { + it('has the expected exports', () => { + expect(Object.keys(indexModule).sort()).toStrictEqual( + expect.arrayContaining(['makeAgent']), + ); + }); +}); diff --git a/packages/kernel-agents/src/index.ts b/packages/kernel-agents/src/index.ts new file mode 100644 index 000000000..8fb5c9191 --- /dev/null +++ b/packages/kernel-agents/src/index.ts @@ -0,0 +1,2 @@ +export { makeAgent } from './agent.ts'; +export type { CapabilityRecord } from './types.ts'; diff --git a/packages/kernel-agents/src/message.test.ts b/packages/kernel-agents/src/message.test.ts new file mode 100644 index 000000000..8a3fe8519 --- /dev/null +++ b/packages/kernel-agents/src/message.test.ts @@ -0,0 +1,33 @@ +import { describe, it, expect } from 'vitest'; + +import { AssistantMessage } from './messages.ts'; + +describe('AssistantMessage', () => { + it('should create an assistant message', () => { + const message = new AssistantMessage({ think: ['test'], invoke: [] }); + expect(message).toBeDefined(); + }); + + it('serializes think before invoke if present', () => { + const message = new AssistantMessage({ + invoke: [{ name: 'test', args: {} }], + think: ['test'], + }); + const json = message.toJSON(); + const [left, right] = json.split('think'); + expect(left).toContain('messageType'); + expect(left).not.toContain('invoke'); + expect(right).not.toContain('messageType'); + expect(right).toContain('invoke'); + }); + + it('serializes if think is not present', () => { + const message = new AssistantMessage({ + invoke: [{ name: 'test', args: {} }], + }); + const json = message.toJSON(); + expect(json).toContain('messageType'); + expect(json).not.toContain('think'); + expect(json).toContain('invoke'); + }); +}); diff --git a/packages/kernel-agents/src/messages.ts b/packages/kernel-agents/src/messages.ts new file mode 100644 index 000000000..7dbb4799f --- /dev/null +++ b/packages/kernel-agents/src/messages.ts @@ -0,0 +1,87 @@ +export type MessageType = + | 'capabilitySpecification' + | 'user' + | 'assistant' + | 'capabilityResult'; + +export class Message< + Type extends MessageType, + Body extends Record, +> { + messageType: Type; + + messageBody: Body; + + constructor(messageType: Type, messageBody: Body) { + this.messageType = messageType; + this.messageBody = messageBody; + } + + toJSON(): string { + return JSON.stringify({ + ...this.messageBody, + messageType: this.messageType, + }); + } +} + +export type Transcript = Message>[]; + +export class CapabilitySpecMessage extends Message< + 'capabilitySpecification', + { schemas: object } +> { + constructor(schemas: object) { + super('capabilitySpecification', { schemas }); + } +} + +export class UserMessage extends Message<'user', { content: string }> { + constructor(content: string) { + super('user', { content }); + } +} + +export type Invocation = { name: string; args: object }; + +export class AssistantMessage extends Message< + 'assistant', + { think?: string[]; invoke: Invocation[] } +> { + constructor({ think, invoke }: { think?: string[]; invoke: Invocation[] }) { + super('assistant', { think: think ?? [], invoke }); + } + + toJSON(): string { + /* JSON.stringify will not preserve the order of the properties. + * To utilize the conditional probability, think precedes invoke. + * Manual serialization prints the properties in prompt order. + */ + const messageType = '"messageType":"assistant",'; + const think = this.messageBody.think?.length + ? `"think":${JSON.stringify(this.messageBody.think)},` + : ''; + const invoke = `"invoke":${JSON.stringify(this.messageBody.invoke)}`; + return ['{', messageType, think, invoke, '}'].join(''); + } +} + +export type AssistantMessageJson = { + messageType: 'assistant'; + think?: string[]; + invoke: Invocation[]; +}; + +export class CapabilityResultMessage extends Message< + 'capabilityResult', + { results: (Invocation & { result: unknown })[] } +> { + constructor(results: (Invocation & { result: unknown })[]) { + super('capabilityResult', { results }); + } +} + +export type CapabilityResultMessageJson = { + messageType: 'capabilityResult'; + results: (Invocation & { result: unknown })[]; +}; diff --git a/packages/kernel-agents/src/parser.test.ts b/packages/kernel-agents/src/parser.test.ts new file mode 100644 index 000000000..6fe6ea68e --- /dev/null +++ b/packages/kernel-agents/src/parser.test.ts @@ -0,0 +1,88 @@ +import type { Logger } from '@metamask/logger'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; + +import { makeIncrementalParser, gatherStreamingResponse } from './parser.ts'; + +describe('makeIncrementalParser', () => { + let mockLogger: Logger; + + beforeEach(() => { + mockLogger = { + info: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + debug: vi.fn(), + } as unknown as Logger; + }); + + it('parses complete JSON in single chunk', () => { + const parser = makeIncrementalParser({}); + expect(parser('{"key": "value"}')).toStrictEqual({ key: 'value' }); + }); + + it('parses JSON across multiple chunks', () => { + const parser = makeIncrementalParser({}); + expect(parser('{"key": "val')).toBeNull(); + expect(parser('ue", "content": 42}')).toStrictEqual({ + key: 'value', + content: 42, + }); + }); + + it('parses JSON with prefix', () => { + const parser = makeIncrementalParser({ prefix: '{"start": true, ' }); + expect(parser('"end": false}')).toStrictEqual({ start: true, end: false }); + }); + + it('logs parsing attempts when logger provided', () => { + const parser = makeIncrementalParser({ logger: mockLogger }); + parser('{"test": "value"}'); + expect(mockLogger.info).toHaveBeenCalledWith( + 'toParse:', + '{"test": "value"}', + ); + }); + + it('throws error for invalid JSON', () => { + const parser = makeIncrementalParser({}); + expect(() => parser('{"invalid": json}')).toThrow('not valid JSON'); + }); + + it('throws error when max chunk count exceeded', () => { + const parser = makeIncrementalParser({ maxChunkCount: 2 }); + parser('chunk1'); + parser('chunk2'); + expect(() => parser('chunk3')).toThrow('Max chunk count reached'); + }); +}); + +describe('gatherStreamingResponse', () => { + it('gathers complete response from single chunk', async () => { + const stream = (async function* () { + yield { response: '{"key": "value"}' }; + })(); + const parser = makeIncrementalParser({}); + const result = await gatherStreamingResponse({ stream, parse: parser }); + expect(result).toStrictEqual({ key: 'value' }); + }); + + it('gathers response from multiple chunks', async () => { + const stream = (async function* () { + yield { response: '{"key": "val' }; + yield { response: 'ue", "content": 42}' }; + })(); + const parser = makeIncrementalParser({}); + const result = await gatherStreamingResponse({ stream, parse: parser }); + expect(result).toStrictEqual({ key: 'value', content: 42 }); + }); + + it('throws error when stream ends without parse event', async () => { + const stream = (async function* () { + yield { response: 'incomplete json' }; + })(); + const parser = makeIncrementalParser({}); + await expect( + gatherStreamingResponse({ stream, parse: parser }), + ).rejects.toThrow('stream ended without a parse event'); + }); +}); diff --git a/packages/kernel-agents/src/parser.ts b/packages/kernel-agents/src/parser.ts new file mode 100644 index 000000000..a87e2465a --- /dev/null +++ b/packages/kernel-agents/src/parser.ts @@ -0,0 +1,81 @@ +import type { Logger } from '@metamask/logger'; + +export type MakeIncrementalParserArgs = { + prefix?: string; + maxChunkCount?: number; + logger?: Logger; +}; +export type IncrementalParser = ( + delta: string, +) => Result | null; +/** + * A quick and dirty 'incremental' parser for a streaming response. + * + * @param args - The arguments to make the incremental parser. + * @param args.prefix - The prefix to prepend to the response + * @param args.maxChunkCount - The maximum number of chunks to parse + * @param args.logger - The logger to use for the incremental parser + * @returns An async function that parses a delta of a streaming response, + * returning the result value if parsing is complete or null otherwise. + */ +export const makeIncrementalParser = ({ + prefix = '', + maxChunkCount = 200, + logger, +}: MakeIncrementalParserArgs): IncrementalParser => { + let response = prefix; + let chunkCount = 0; + let leftBracketCount = prefix.split('{').length - 1; + let rightBracketCount = prefix.split('}').length - 1; + return (delta: string) => { + chunkCount += 1; + const subchunks = delta.split('}'); + const lastSubchunk = subchunks.pop() as string; + for (const subchunk of subchunks) { + rightBracketCount += 1; + leftBracketCount += subchunk.split('{').length - 1; + response += `${subchunk}}`; + logger?.info('toParse:', response); + try { + return JSON.parse(response); + } catch (error) { + // XXX There are other ways to detect an irrecoverable state. + // This is the simplest. + if (leftBracketCount === rightBracketCount) { + throw error; + } + } + } + leftBracketCount += lastSubchunk.split('{').length - 1; + response += lastSubchunk; + if (maxChunkCount && chunkCount > maxChunkCount) { + throw new Error(`Max chunk count reached with response:\n${response}`); + } + return null; + }; +}; + +/** + * Gather a streaming response from an stream of chunks. + * + * @param args - The arguments to gather the streaming response. + * @param args.stream - The stream to gather from. + * @param args.parse - The incremental parser to use to parse the response. + * @returns The parsed response. + */ +export const gatherStreamingResponse = async ({ + stream, + parse, +}: { + stream: AsyncIterable<{ response: string }>; + parse: IncrementalParser; +}): Promise => { + for await (const chunk of stream) { + const delta = (chunk as { response: string }).response; + const parsed = parse(delta); + if (parsed !== null) { + return parsed; + } + } + throw new Error('stream ended without a parse event'); +}; diff --git a/packages/kernel-agents/src/prompt.test.ts b/packages/kernel-agents/src/prompt.test.ts new file mode 100644 index 000000000..15a8ab09e --- /dev/null +++ b/packages/kernel-agents/src/prompt.test.ts @@ -0,0 +1,29 @@ +import { describe, it, expect } from 'vitest'; + +import { AssistantMessage } from './messages.ts'; +import type { Transcript } from './messages.ts'; +import { makeChat } from './prompt.ts'; + +describe('makeChat', () => { + it('should make a chat', () => { + const chat = makeChat({}, 'test'); + expect(chat).toBeDefined(); + expect(chat).toHaveProperty('getPromptAndPrefix'); + expect(chat).toHaveProperty('pushMessages'); + }); + + it('should get the prompt and prefix', () => { + const chat = makeChat({}, 'test'); + const { prompt, prefix } = chat.getPromptAndPrefix(); + expect(prompt).toBeDefined(); + expect(prefix).toBeDefined(); + }); + + it('should push a transcript', () => { + const transcript: Transcript = []; + const chat = makeChat({}, 'test', transcript); + const testMessage = new AssistantMessage({ think: ['test'], invoke: [] }); + chat.pushMessages(testMessage); + expect(transcript.pop()).toStrictEqual(testMessage); + }); +}); diff --git a/packages/kernel-agents/src/prompt.ts b/packages/kernel-agents/src/prompt.ts new file mode 100644 index 000000000..6fd086501 --- /dev/null +++ b/packages/kernel-agents/src/prompt.ts @@ -0,0 +1,52 @@ +import { extractCapabilitySchemas } from './capability.ts'; +import { exampleTranscripts } from './example-transcripts.ts'; +import { CapabilitySpecMessage, UserMessage } from './messages.ts'; +import type { Transcript } from './messages.ts'; +import type { CapabilityRecord, Chat } from './types.ts'; + +const stringifyTranscript = (transcript: Transcript, index: number): string => + [ + `TRANSCRIPT ${index + 1}: [`, + transcript.map((message) => message.toJSON()).join(', '), + `]`, + ].join(' '); + +export const makeChat = ( + capabilities: CapabilityRecord, + query: string, + transcript: Transcript = [], +): Chat => { + transcript.push( + new CapabilitySpecMessage(extractCapabilitySchemas(capabilities)), + new UserMessage(query), + ); + const transcripts = [...exampleTranscripts, transcript]; + const preamble = [ + `The following are ${transcripts.length} transcripts of conversations between a user and a state-of-the-art capability-augmented assistant.`, + `Each transcript begins with a JSON-formatted list of the assistant's available capabilities, then proceeds to the conversation history, including user messages, assistant capability invocations, and the results of those invocations.`, + `Note that the assistant efficiently invokes capabilities to perform tasks. This reflects that the assistant is intelligent and can reason logically about function composition, and prefers to invoke external capabilities to prove the correctness of its answers.`, + `Also note that, although the assistant does not necessarily use every available capability, it never attempts to use a capability that was not specified prior in the transcript.`, + ].join('\n'); + /** + * The assistant must either immediately invoke a capability, or think and then + * invoke a capability. In either case, the next piece of the transcript must + * begin with this incomplete JSON prefix. + * + * XXX Subtle changes in the prefix can disrupt the tokenized pattern; + * this prompt string is aligned to llama3's implicit tokenizer boundaries. + */ + const responsePrefix = `{"messageType":"assistant","`; + return { + getPromptAndPrefix: () => { + const rawPrompt = [ + preamble, + ...transcripts.map(stringifyTranscript), + ].join('\n\n'); + const prompt = `${rawPrompt.slice(0, rawPrompt.length - 1)}, ${responsePrefix}`; + return { prompt, prefix: responsePrefix }; + }, + pushMessages: (...messages: Transcript) => { + transcript.push(...messages); + }, + }; +}; diff --git a/packages/kernel-agents/src/types.ts b/packages/kernel-agents/src/types.ts new file mode 100644 index 000000000..49e58f8a7 --- /dev/null +++ b/packages/kernel-agents/src/types.ts @@ -0,0 +1,65 @@ +import type { Transcript } from './messages.ts'; + +export type JsonSchema = + | PrimitiveJsonSchema + | ArrayJsonSchema + | ObjectJsonSchemaProperty; + +type PrimitiveJsonSchema = { + type: 'string' | 'number' | 'boolean'; + description?: string; +}; + +type ArrayJsonSchema = { + type: 'array'; + description?: string; + item: JsonSchema; +}; + +type ObjectJsonSchemaProperty = { + type: 'object'; + description?: string; + properties: { + [key: string]: JsonSchema; + }; + required?: string[]; + additionalProperties?: boolean; +}; + +export type Capability, Return = null> = ( + args: Args, +) => Promise; + +export type CapabilitySchema = { + description: string; + args: Record; + returns?: JsonSchema; +}; + +export type ExtractRecordKeys = + Rec extends Record ? Key : never; + +export type CapabilitySpec< + Args extends Record = Record, + Return = void, +> = { + func: Capability; + schema: CapabilitySchema>; +}; + +export type CapabilityRecord = Record< + Keys, + CapabilitySpec +>; + +export type Agent = { + task: ( + prompt: string, + options?: { invocationBudget?: number }, + ) => Promise; +}; + +export type Chat = { + getPromptAndPrefix: () => { prompt: string; prefix: string }; + pushMessages: (...messages: Transcript) => void; +}; diff --git a/packages/kernel-agents/test/constants.ts b/packages/kernel-agents/test/constants.ts new file mode 100644 index 000000000..17d108edf --- /dev/null +++ b/packages/kernel-agents/test/constants.ts @@ -0,0 +1,10 @@ +/** + * Test constants for E2E tests + */ +export const DEFAULT_MODEL = 'llama3.1:latest'; + +/** + * Ollama API endpoints + */ +export const OLLAMA_API_BASE = 'http://localhost:11434'; +export const OLLAMA_TAGS_ENDPOINT = `${OLLAMA_API_BASE}/api/tags`; diff --git a/packages/kernel-agents/test/e2e/agent.test.ts b/packages/kernel-agents/test/e2e/agent.test.ts new file mode 100644 index 000000000..921044d4c --- /dev/null +++ b/packages/kernel-agents/test/e2e/agent.test.ts @@ -0,0 +1,115 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import { Logger } from '@metamask/logger'; +import { OllamaNodejsService } from '@ocap/kernel-language-model-service/ollama/nodejs'; +import { fetchMock } from '@ocap/repo-tools/test-utils/fetch-mock'; +import { + afterAll, + beforeAll, + beforeEach, + describe, + expect, + it, + vi, +} from 'vitest'; + +import { makeAgent } from '../../src/agent.ts'; +import { count, add, multiply } from '../../src/example-capabilities.ts'; +import { DEFAULT_MODEL } from '../constants.ts'; + +/** + * Generate a random letter. + * + * @returns a random letter. + */ +function randomLetter(): string { + return String.fromCharCode(Math.floor(Math.random() * 26) + 97); +} + +const logger = new Logger('test'); + +describe('agent', () => { + beforeAll(() => { + fetchMock.disableMocks(); + }); + + afterAll(() => { + fetchMock.enableMocks(); + }); + + let llmService: OllamaNodejsService; + beforeEach(() => { + llmService = new OllamaNodejsService({ endowments: { fetch } }); + }); + + it( + 'should create an agent and process a request', + { + retry: 3, + timeout: 5_000, + }, + async () => { + const llm = await llmService.makeInstance({ model: DEFAULT_MODEL }); + const agent = makeAgent({ llm, capabilities: {}, logger }); + expect(agent).toBeDefined(); + + const letter = randomLetter().toUpperCase(); + const result = (await agent.task( + `Name an animal that starts with the letter "${letter}"`, + )) as string; + expect(result).toBeDefined(); + expect(result.length).toBeGreaterThan(0); + expect(result).toMatch( + new RegExp(`(${letter}|${letter.toLowerCase()})+`, 'u'), + ); + }, + ); + + it( + 'should create an agent that uses tools', + { + retry: 3, + timeout: 5_000, + }, + async () => { + const llm = await llmService.makeInstance({ model: DEFAULT_MODEL }); + const word = 'xf9147qsdhdkj'; + const countSpy = vi.spyOn(count, 'func'); + const agent = makeAgent({ llm, capabilities: { count }, logger }); + expect(agent).toBeDefined(); + const result = await agent.task( + `What is the length of the word "${word}"?`, + ); + expect(result).toBeDefined(); + expect(countSpy).toHaveBeenCalled(); + expect(result).includes(word.length.toString()); + }, + ); + + it( + 'performs multi-step calculations', + { + retry: 3, + timeout: 10_000, + }, + async () => { + const llm = await llmService.makeInstance({ model: DEFAULT_MODEL }); + const agent = makeAgent({ + llm, + capabilities: { + count, + add, + multiply, + }, + logger, + }); + expect(agent).toBeDefined(); + const [length, width, height] = [11, 47, 63]; + const result = await agent.task( + `A box with length ${length}, width ${width}, and height ${height} have volume V. How many digits are in the numerical value of V?`, + ); + expect(result).toBeDefined(); + expect(result).includes(`${length * width * height}`.length.toString()); + }, + ); +}); diff --git a/packages/kernel-agents/test/e2e/suite.test.ts b/packages/kernel-agents/test/e2e/suite.test.ts new file mode 100644 index 000000000..c407ef46b --- /dev/null +++ b/packages/kernel-agents/test/e2e/suite.test.ts @@ -0,0 +1,40 @@ +import { fetchMock } from '@ocap/repo-tools/test-utils/fetch-mock'; +import { afterAll, beforeAll, describe, expect, it } from 'vitest'; + +import { + DEFAULT_MODEL, + OLLAMA_API_BASE, + OLLAMA_TAGS_ENDPOINT, +} from '../constants.ts'; + +describe.sequential('test suite', () => { + beforeAll(() => { + fetchMock.disableMocks(); + }); + + afterAll(() => { + fetchMock.enableMocks(); + }); + + it(`connects to Ollama instance`, async () => { + const response = await fetch(OLLAMA_API_BASE); + expect(response.ok).toBe(true); + }); + + it(`can access ${DEFAULT_MODEL} model`, async () => { + const response = await fetch(OLLAMA_TAGS_ENDPOINT); + expect(response.ok).toBe(true); + + const data = (await response.json()) as { + models: { name: string }[]; + }; + expect(data?.models).toBeDefined(); + expect(Array.isArray(data.models)).toBe(true); + + const llamaModel = data.models.find( + (foundModel: { name: string }) => foundModel.name === DEFAULT_MODEL, + ); + expect(llamaModel).toBeDefined(); + expect(llamaModel?.name).toBe(DEFAULT_MODEL); + }); +}); diff --git a/packages/kernel-agents/test/utils.ts b/packages/kernel-agents/test/utils.ts new file mode 100644 index 000000000..8870d4ef9 --- /dev/null +++ b/packages/kernel-agents/test/utils.ts @@ -0,0 +1,13 @@ +/** + * Construct a bundle path URL from a bundle name. + * + * @param bundleName - The name of the bundle. + * + * @returns a path string for the named bundle. + */ +export function getBundleSpec(bundleName: string): string { + return new URL( + `../kernel-test/src/vats/${bundleName}.bundle`, + import.meta.url, + ).toString(); +} diff --git a/packages/kernel-agents/tsconfig.build.json b/packages/kernel-agents/tsconfig.build.json new file mode 100644 index 000000000..14b63bcd1 --- /dev/null +++ b/packages/kernel-agents/tsconfig.build.json @@ -0,0 +1,17 @@ +{ + "extends": "../../tsconfig.packages.build.json", + "compilerOptions": { + "baseUrl": "./", + "lib": ["ES2022"], + "outDir": "./dist", + "rootDir": "./src", + "types": ["node"] + }, + "references": [ + { "path": "../kernel-language-model-service/tsconfig.build.json" }, + { "path": "../kernel-utils/tsconfig.build.json" }, + { "path": "../logger/tsconfig.build.json" } + ], + "files": [], + "include": ["./src"] +} diff --git a/packages/kernel-agents/tsconfig.json b/packages/kernel-agents/tsconfig.json new file mode 100644 index 000000000..08d6aabb6 --- /dev/null +++ b/packages/kernel-agents/tsconfig.json @@ -0,0 +1,22 @@ +{ + "extends": "../../tsconfig.packages.json", + "compilerOptions": { + "baseUrl": "./", + "lib": ["ES2022"], + "types": ["vitest", "node"] + }, + "references": [ + { "path": "../kernel-language-model-service" }, + { "path": "../kernel-utils" }, + { "path": "../logger" }, + { "path": "../repo-tools" } + ], + "include": [ + "../../vitest.config.ts", + "./src", + "./test", + "./vite.config.ts", + "./vitest.config.ts", + "./vitest.config.e2e.ts" + ] +} diff --git a/packages/kernel-agents/typedoc.json b/packages/kernel-agents/typedoc.json new file mode 100644 index 000000000..f8eb78ae1 --- /dev/null +++ b/packages/kernel-agents/typedoc.json @@ -0,0 +1,8 @@ +{ + "entryPoints": [], + "excludePrivate": true, + "hideGenerator": true, + "out": "docs", + "tsconfig": "./tsconfig.build.json", + "projectDocuments": ["documents/*.md"] +} diff --git a/packages/kernel-agents/vitest.config.e2e.ts b/packages/kernel-agents/vitest.config.e2e.ts new file mode 100644 index 000000000..67291d043 --- /dev/null +++ b/packages/kernel-agents/vitest.config.e2e.ts @@ -0,0 +1,22 @@ +import { mergeConfig } from '@ocap/repo-tools/vitest-config'; +import { defineConfig, defineProject } from 'vitest/config'; + +import defaultConfig from '../../vitest.config.ts'; + +export default defineConfig((args) => { + return mergeConfig( + args, + defaultConfig, + defineProject({ + test: { + name: 'kernel-agents-e2e', + // E2E test configuration + testTimeout: 30000, + hookTimeout: 10000, + + // Include only E2E tests + include: ['./test/e2e/**/*.test.ts'], + }, + }), + ); +}); diff --git a/packages/kernel-agents/vitest.config.ts b/packages/kernel-agents/vitest.config.ts new file mode 100644 index 000000000..a04eee63f --- /dev/null +++ b/packages/kernel-agents/vitest.config.ts @@ -0,0 +1,19 @@ +import { mergeConfig } from '@ocap/repo-tools/vitest-config'; +import { defineConfig, defineProject } from 'vitest/config'; + +import defaultConfig from '../../vitest.config.ts'; + +export default defineConfig((args) => { + return mergeConfig( + args, + defaultConfig, + defineProject({ + test: { + name: 'kernel-agents', + include: ['src/**/*.test.ts'], + // Exclude E2E setup test from regular test runs + exclude: ['test/e2e'], + }, + }), + ); +}); diff --git a/packages/kernel-language-model-service/package.json b/packages/kernel-language-model-service/package.json index ba731c878..a7964e624 100644 --- a/packages/kernel-language-model-service/package.json +++ b/packages/kernel-language-model-service/package.json @@ -13,6 +13,16 @@ }, "type": "module", "exports": { + ".": { + "import": { + "types": "./dist/index.d.mts", + "default": "./dist/index.mjs" + }, + "require": { + "types": "./dist/index.d.cts", + "default": "./dist/index.cjs" + } + }, "./ollama/nodejs": { "import": { "types": "./dist/ollama/nodejs.d.mts", diff --git a/packages/kernel-language-model-service/src/index.ts b/packages/kernel-language-model-service/src/index.ts new file mode 100644 index 000000000..c20e32944 --- /dev/null +++ b/packages/kernel-language-model-service/src/index.ts @@ -0,0 +1 @@ +export type * from './types.ts'; diff --git a/packages/kernel-language-model-service/src/ollama/base.test.ts b/packages/kernel-language-model-service/src/ollama/base.test.ts index 6661e37b2..361d926ce 100644 --- a/packages/kernel-language-model-service/src/ollama/base.test.ts +++ b/packages/kernel-language-model-service/src/ollama/base.test.ts @@ -175,10 +175,14 @@ describe('OllamaBaseService', () => { makeMockAbortableAsyncIterator([mockResponse]), ); - const result = await instance.sample(prompt, options); - - for await (const chunk of result) { - expect(chunk).toMatchObject(mockResponse); + const { stream, abort } = await instance.sample(prompt, options); + + try { + for await (const chunk of stream) { + expect(chunk).toMatchObject(mockResponse); + } + } finally { + await abort(); } expect(mockClient.generate).toHaveBeenCalledWith({ diff --git a/packages/kernel-language-model-service/src/ollama/base.ts b/packages/kernel-language-model-service/src/ollama/base.ts index c97230211..66ef88fba 100644 --- a/packages/kernel-language-model-service/src/ollama/base.ts +++ b/packages/kernel-language-model-service/src/ollama/base.ts @@ -83,11 +83,14 @@ export class OllamaBaseService ...mandatoryOptions, prompt, }); - return (async function* () { - for await (const chunk of response) { - yield chunk; - } - })(); + return { + stream: (async function* () { + for await (const chunk of response) { + yield chunk; + } + })(), + abort: async () => response.abort(), + }; }, }; return harden(instance); diff --git a/packages/kernel-language-model-service/src/ollama/nodejs.test.ts b/packages/kernel-language-model-service/src/ollama/nodejs.test.ts index aa069dbc9..882dcb1bd 100644 --- a/packages/kernel-language-model-service/src/ollama/nodejs.test.ts +++ b/packages/kernel-language-model-service/src/ollama/nodejs.test.ts @@ -77,13 +77,18 @@ describe('OllamaNodejsService', () => { } as Parameters[0]); const instance = await service.makeInstance({ model }); - for await (const chunk of await instance.sample('Hello, ')) { - expect(chunk).toMatchObject({ - response, - done: true, - // eslint-disable-next-line @typescript-eslint/naming-convention - done_reason: 'stop', - }); + const { stream, abort } = await instance.sample('Hello, '); + try { + for await (const chunk of stream) { + expect(chunk).toMatchObject({ + response, + done: true, + // eslint-disable-next-line @typescript-eslint/naming-convention + done_reason: 'stop', + }); + } + } finally { + await abort(); } }); }); diff --git a/packages/kernel-language-model-service/src/types.ts b/packages/kernel-language-model-service/src/types.ts index b5ed239f5..5895986c4 100644 --- a/packages/kernel-language-model-service/src/types.ts +++ b/packages/kernel-language-model-service/src/types.ts @@ -46,7 +46,10 @@ export type LanguageModel = { sample: ( prompt: string, options?: Partial, - ) => Promise>; + ) => Promise<{ + stream: AsyncIterable; + abort: () => Promise; + }>; }; /** diff --git a/packages/kernel-language-model-service/test/e2e/ollama.test.ts b/packages/kernel-language-model-service/test/e2e/ollama.test.ts index 5946359c6..80b771db8 100644 --- a/packages/kernel-language-model-service/test/e2e/ollama.test.ts +++ b/packages/kernel-language-model-service/test/e2e/ollama.test.ts @@ -54,24 +54,14 @@ describe('OllamaNodejsService E2E', { timeout: 10_000 }, () => { it('should return a streaming result', async () => { const prompt = 'A B C'; let completion = prompt; - const response = await instance.sample(prompt); - let exitEarly = false; - await Promise.all([ - (async () => { - for await (const chunk of response) { - if (exitEarly) { - return; - } - completion += chunk.response; - } - })(), - new Promise((resolve) => - setTimeout(() => { - exitEarly = true; - resolve(undefined); - }), - ), - ]); + const { stream, abort } = await instance.sample(prompt); + try { + for await (const chunk of stream) { + completion += chunk.response; + } + } finally { + await abort(); + } console.debug('@@@ sample: ', completion); expect(completion).toContain(prompt); expect(completion.length).toBeGreaterThan(prompt.length); diff --git a/tsconfig.build.json b/tsconfig.build.json index 42b5218ac..ecf57a430 100644 --- a/tsconfig.build.json +++ b/tsconfig.build.json @@ -3,6 +3,7 @@ "include": [], "references": [ { "path": "./packages/cli/tsconfig.build.json" }, + { "path": "./packages/kernel-agents/tsconfig.build.json" }, { "path": "./packages/kernel-browser-runtime/tsconfig.build.json" }, { "path": "./packages/kernel-errors/tsconfig.build.json" }, { "path": "./packages/kernel-language-model-service/tsconfig.build.json" }, diff --git a/tsconfig.json b/tsconfig.json index b9680cbdc..c7c624374 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -16,6 +16,7 @@ { "path": "./packages/cli" }, { "path": "./packages/create-package" }, { "path": "./packages/extension" }, + { "path": "./packages/kernel-agents" }, { "path": "./packages/kernel-browser-runtime" }, { "path": "./packages/kernel-errors" }, { "path": "./packages/kernel-language-model-service" }, diff --git a/vitest.config.ts b/vitest.config.ts index 83d678d7e..82f21ca83 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -83,6 +83,12 @@ export default defineConfig({ branches: 0, lines: 1.78, }, + 'packages/kernel-agents/**': { + statements: 100, + functions: 100, + branches: 100, + lines: 100, + }, 'packages/kernel-browser-runtime/**': { statements: 71.84, functions: 69.11, diff --git a/yarn.lock b/yarn.lock index c592f0164..d968304c7 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3254,7 +3254,47 @@ __metadata: languageName: unknown linkType: soft -"@ocap/kernel-language-model-service@workspace:packages/kernel-language-model-service": +"@ocap/kernel-agents@workspace:packages/kernel-agents": + version: 0.0.0-use.local + resolution: "@ocap/kernel-agents@workspace:packages/kernel-agents" + dependencies: + "@arethetypeswrong/cli": "npm:^0.17.4" + "@metamask/auto-changelog": "npm:^5.0.1" + "@metamask/eslint-config": "npm:^14.0.0" + "@metamask/eslint-config-nodejs": "npm:^14.0.0" + "@metamask/eslint-config-typescript": "npm:^14.0.0" + "@metamask/kernel-utils": "workspace:^" + "@metamask/logger": "workspace:^" + "@ocap/kernel-language-model-service": "workspace:^" + "@ocap/repo-tools": "workspace:^" + "@ts-bridge/cli": "npm:^0.6.3" + "@ts-bridge/shims": "npm:^0.1.1" + "@types/node": "npm:^22.13.1" + "@typescript-eslint/eslint-plugin": "npm:^8.29.0" + "@typescript-eslint/parser": "npm:^8.29.0" + "@typescript-eslint/utils": "npm:^8.29.0" + "@vitest/eslint-plugin": "npm:^1.3.4" + depcheck: "npm:^1.4.7" + eslint: "npm:^9.23.0" + eslint-config-prettier: "npm:^10.1.1" + eslint-import-resolver-typescript: "npm:^4.3.1" + eslint-plugin-import-x: "npm:^4.10.0" + eslint-plugin-jsdoc: "npm:^50.6.9" + eslint-plugin-n: "npm:^17.17.0" + eslint-plugin-prettier: "npm:^5.2.6" + eslint-plugin-promise: "npm:^7.2.1" + prettier: "npm:^3.5.3" + rimraf: "npm:^6.0.1" + turbo: "npm:^2.5.6" + typedoc: "npm:^0.28.1" + typescript: "npm:~5.8.2" + typescript-eslint: "npm:^8.29.0" + vite: "npm:^7.1.2" + vitest: "npm:^3.2.4" + languageName: unknown + linkType: soft + +"@ocap/kernel-language-model-service@workspace:^, @ocap/kernel-language-model-service@workspace:packages/kernel-language-model-service": version: 0.0.0-use.local resolution: "@ocap/kernel-language-model-service@workspace:packages/kernel-language-model-service" dependencies: