diff --git a/.changeset/custom-method-handlers.md b/.changeset/custom-method-handlers.md new file mode 100644 index 000000000..498d585d2 --- /dev/null +++ b/.changeset/custom-method-handlers.md @@ -0,0 +1,6 @@ +--- +'@modelcontextprotocol/client': minor +'@modelcontextprotocol/server': minor +--- + +Add `setCustomRequestHandler` / `setCustomNotificationHandler` / `sendCustomRequest` / `sendCustomNotification` (plus `remove*` variants) on `Protocol` for non-standard JSON-RPC methods. Restores typed registration for vendor-specific methods (e.g. `mcp-ui/*`) that #1446/#1451 closed off, without reintroducing class-level generics. Handlers share the standard dispatch path (context, cancellation, tasks); a collision guard rejects standard MCP methods. diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index f581c0cb6..76c1307cc 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -377,6 +377,19 @@ Schema to method string mapping: Request/notification params remain fully typed. Remove unused schema imports after migration. +**Custom (non-standard) methods** — vendor extensions or sub-protocols whose method strings are not in the MCP spec — are no longer accepted by `setRequestHandler`/`setNotificationHandler`. Use the `*Custom*` API instead: + +| v1 | v2 | +| ------------------------------------------------------------ | ------------------------------------------------------------------------------ | +| `setRequestHandler(CustomReqSchema, (req, extra) => ...)` | `setCustomRequestHandler('vendor/method', ParamsSchema, (params, ctx) => ...)` | +| `setNotificationHandler(CustomNotifSchema, n => ...)` | `setCustomNotificationHandler('vendor/method', ParamsSchema, params => ...)` | +| `this.request({ method: 'vendor/x', params }, ResultSchema)` | `this.sendCustomRequest('vendor/x', params, ResultSchema)` | +| `this.notification({ method: 'vendor/x', params })` | `this.sendCustomNotification('vendor/x', params)` | +| `class X extends Protocol` | `class X extends Client` (or `Server`), or compose a `Client` instance | + +The v1 schema's `.shape.params` becomes the `ParamsSchema` argument; the `method: z.literal('...')` value becomes the string argument. + + ## 10. Request Handler Context Types `RequestHandlerExtra` → structured context types with nested groups. Rename `extra` → `ctx` in all handler callbacks. diff --git a/docs/migration.md b/docs/migration.md index 14fc719db..8a63a1162 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -382,6 +382,58 @@ Common method string replacements: | `ResourceListChangedNotificationSchema` | `'notifications/resources/list_changed'` | | `PromptListChangedNotificationSchema` | `'notifications/prompts/list_changed'` | +### Custom (non-standard) protocol methods + +In v1, `setRequestHandler` accepted any Zod schema with a `method: z.literal('...')` shape, so vendor-specific methods (e.g. `mcp-ui/initialize`) could be registered the same way as spec methods. The `Protocol` generics widened the +send-side types to match. + +In v2, `setRequestHandler`/`setNotificationHandler` accept only standard MCP method strings, and the class-level send-side generics have been removed. For methods outside the MCP spec, use the dedicated `*Custom*` methods on `Client` and `Server` (inherited from `Protocol`): + +**Before (v1):** + +```typescript +import { Protocol } from '@modelcontextprotocol/sdk/shared/protocol.js'; + +const SearchRequestSchema = z.object({ + method: z.literal('acme/search'), + params: z.object({ query: z.string() }) +}); + +class App extends Protocol { + constructor() { + super(); + this.setRequestHandler(SearchRequestSchema, req => ({ hits: [req.params.query] })); + } + search(query: string) { + return this.request({ method: 'acme/search', params: { query } }, SearchResultSchema); + } +} +``` + +**After (v2):** + +```typescript +import { Client } from '@modelcontextprotocol/client'; + +const SearchParams = z.object({ query: z.string() }); +const SearchResult = z.object({ hits: z.array(z.string()) }); + +class App extends Client { + constructor() { + super({ name: 'app', version: '1.0.0' }); + this.setCustomRequestHandler('acme/search', SearchParams, params => ({ hits: [params.query] })); + } + search(query: string) { + return this.sendCustomRequest('acme/search', { query }, { params: SearchParams, result: SearchResult }); + } +} +``` + +Custom handlers share the same dispatch path as standard handlers — context, cancellation, task delivery, and error wrapping all apply. Passing a `{ params, result }` schema bundle to `sendCustomRequest` (or `{ params }` to `sendCustomNotification`) validates outbound params +before sending and gives typed `params`; passing a bare result schema sends params unvalidated. + +For larger sub-protocols where neither side is semantically an MCP client or server, prefer composition: hold a `Client` (or `Server`) instance, register custom handlers on it, and expose typed facade methods. See `examples/server/src/customMethodExample.ts` and `examples/client/src/customMethodExample.ts` for runnable examples. + ### `Protocol.request()`, `ctx.mcpReq.send()`, and `Client.callTool()` no longer take a schema parameter The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas diff --git a/examples/client/README.md b/examples/client/README.md index 12a2b0d68..8eca78879 100644 --- a/examples/client/README.md +++ b/examples/client/README.md @@ -36,6 +36,7 @@ Most clients expect a server to be running. Start one from [`../server/README.md | Client credentials (M2M) | Machine-to-machine OAuth client credentials example. | [`src/simpleClientCredentials.ts`](src/simpleClientCredentials.ts) | | URL elicitation client | Drives URL-mode elicitation flows (sensitive input in a browser). | [`src/elicitationUrlExample.ts`](src/elicitationUrlExample.ts) | | Task interactive client | Demonstrates task-based execution + interactive server→client requests. | [`src/simpleTaskInteractiveClient.ts`](src/simpleTaskInteractiveClient.ts) | +| Custom (non-standard) methods client | Sends `acme/*` custom requests and handles custom server notifications. | [`src/customMethodExample.ts`](src/customMethodExample.ts) | ## URL elicitation example (server + client) diff --git a/examples/client/src/customMethodExample.ts b/examples/client/src/customMethodExample.ts new file mode 100644 index 000000000..46b8413b4 --- /dev/null +++ b/examples/client/src/customMethodExample.ts @@ -0,0 +1,80 @@ +// Run with: pnpm tsx src/customMethodExample.ts +// +// Demonstrates sending custom (non-standard) requests and receiving custom +// notifications from the server. +// +// The Protocol class exposes sendCustomRequest / setCustomNotificationHandler for +// vendor-specific methods that are not part of the MCP spec. The schema-bundle +// overload of sendCustomRequest gives typed params with pre-send validation. +// +// Pair with: examples/server/src/customMethodExample.ts (start the server first). + +import { Client, ProtocolError, ProtocolErrorCode, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; +import { z } from 'zod'; + +const SearchParamsSchema = z.object({ + query: z.string(), + limit: z.number().int().positive().optional() +}); + +const SearchResultSchema = z.object({ + results: z.array(z.object({ id: z.string(), title: z.string() })), + total: z.number() +}); + +const AnalyticsResultSchema = z.object({ recorded: z.boolean() }); + +const StatusUpdateParamsSchema = z.object({ + status: z.enum(['idle', 'busy', 'error']), + detail: z.string().optional() +}); + +const serverUrl = process.argv[2] ?? 'http://localhost:3000/mcp'; + +async function main(): Promise { + const client = new Client({ name: 'custom-method-client', version: '1.0.0' }); + + // Register handler for custom server→client notifications before connecting. + client.setCustomNotificationHandler('acme/statusUpdate', StatusUpdateParamsSchema, params => { + console.log(`[client] acme/statusUpdate status=${params.status} detail=${params.detail ?? ''}`); + }); + + const transport = new StreamableHTTPClientTransport(new URL(serverUrl)); + await client.connect(transport); + console.log(`[client] connected to ${serverUrl}`); + + // Schema-bundle overload: typed params + pre-send validation, typed result. + const searchResult = await client.sendCustomRequest( + 'acme/search', + { query: 'widgets', limit: 5 }, + { params: SearchParamsSchema, result: SearchResultSchema } + ); + console.log(`[client] acme/search → ${searchResult.total} results, first: "${searchResult.results[0]?.title}"`); + + // Loose overload: bare result schema, untyped params. + const analyticsResult = await client.sendCustomRequest('acme/analytics', { event: 'page_view' }, AnalyticsResultSchema); + console.log(`[client] acme/analytics → recorded=${analyticsResult.recorded}`); + + // Pre-send validation: schema-bundle overload rejects bad params before the round-trip. + try { + await client.sendCustomRequest( + 'acme/search', + { query: 'widgets', limit: 'five' } as unknown as z.output, + { params: SearchParamsSchema, result: SearchResultSchema } + ); + console.error('[client] expected validation error but request succeeded'); + } catch (error) { + const code = error instanceof ProtocolError && error.code === ProtocolErrorCode.InvalidParams ? 'InvalidParams' : 'unknown'; + console.log(`[client] pre-send validation error (expected, ${code}): ${(error as Error).message}`); + } + + await transport.close(); +} + +try { + await main(); +} catch (error) { + console.error('[client] error:', error); + // eslint-disable-next-line unicorn/no-process-exit + process.exit(1); +} diff --git a/examples/server/README.md b/examples/server/README.md index 384e4f2c2..1a217de0e 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -38,6 +38,7 @@ pnpm tsx src/simpleStreamableHttp.ts | Task interactive server | Task-based execution with interactive server→client requests. | [`src/simpleTaskInteractive.ts`](src/simpleTaskInteractive.ts) | | Hono Streamable HTTP server | Streamable HTTP server built with Hono instead of Express. | [`src/honoWebStandardStreamableHttp.ts`](src/honoWebStandardStreamableHttp.ts) | | SSE polling demo server | Legacy SSE server intended for polling demos. | [`src/ssePollingExample.ts`](src/ssePollingExample.ts) | +| Custom (non-standard) methods server | Registers `acme/*` custom request handlers and sends custom notifications. | [`src/customMethodExample.ts`](src/customMethodExample.ts) | ## OAuth demo flags (Streamable HTTP server) diff --git a/examples/server/src/customMethodExample.ts b/examples/server/src/customMethodExample.ts new file mode 100644 index 000000000..6afc1af21 --- /dev/null +++ b/examples/server/src/customMethodExample.ts @@ -0,0 +1,123 @@ +// Run with: pnpm tsx src/customMethodExample.ts +// +// Demonstrates registering handlers for custom (non-standard) request methods +// and sending custom notifications back to the client. +// +// The Protocol class exposes setCustomRequestHandler / sendCustomNotification for +// vendor-specific methods that are not part of the MCP spec. Params are validated +// against user-provided Zod schemas, and handlers receive the same context +// (cancellation, bidirectional send/notify) as standard handlers. +// +// Pair with: examples/client/src/customMethodExample.ts + +import { randomUUID } from 'node:crypto'; + +import { createMcpExpressApp } from '@modelcontextprotocol/express'; +import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; +import { isInitializeRequest, Server } from '@modelcontextprotocol/server'; +import type { Request, Response } from 'express'; +import { z } from 'zod'; + +const SearchParamsSchema = z.object({ + query: z.string(), + limit: z.number().int().positive().optional() +}); + +const AnalyticsParamsSchema = z.object({ + event: z.string(), + properties: z.record(z.string(), z.unknown()).optional() +}); + +const getServer = () => { + const server = new Server({ name: 'custom-method-server', version: '1.0.0' }, { capabilities: {} }); + + server.setCustomRequestHandler('acme/search', SearchParamsSchema, async (params, ctx) => { + console.log(`[server] acme/search query="${params.query}" limit=${params.limit ?? 'unset'} (req ${ctx.mcpReq.id})`); + + // Send a custom server→client notification on the same SSE stream as this response + // (relatedRequestId routes it to the request's stream rather than the standalone SSE stream). + await server.sendCustomNotification( + 'acme/statusUpdate', + { status: 'busy', detail: `searching "${params.query}"` }, + { relatedRequestId: ctx.mcpReq.id } + ); + + return { + results: [ + { id: 'r1', title: `Result for "${params.query}"` }, + { id: 'r2', title: 'Another result' } + ], + total: 2 + }; + }); + + server.setCustomRequestHandler('acme/analytics', AnalyticsParamsSchema, async params => { + console.log(`[server] acme/analytics event="${params.event}"`); + return { recorded: true }; + }); + + return server; +}; + +const PORT = process.env.PORT ? Number.parseInt(process.env.PORT, 10) : 3000; +const app = createMcpExpressApp(); +const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; + +app.post('/mcp', async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + try { + let transport: NodeStreamableHTTPServerTransport; + if (sessionId && transports[sessionId]) { + transport = transports[sessionId]; + } else if (!sessionId && isInitializeRequest(req.body)) { + transport = new NodeStreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: sid => { + transports[sid] = transport; + } + }); + transport.onclose = () => { + const sid = transport.sessionId; + if (sid) delete transports[sid]; + }; + const server = getServer(); + await server.connect(transport); + } else { + res.status(400).json({ jsonrpc: '2.0', error: { code: -32_000, message: 'No valid session ID' }, id: null }); + return; + } + await transport.handleRequest(req, res, req.body); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ jsonrpc: '2.0', error: { code: -32_603, message: 'Internal server error' }, id: null }); + } + } +}); + +const handleSessionRequest = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } + await transports[sessionId].handleRequest(req, res); +}; + +app.get('/mcp', handleSessionRequest); +app.delete('/mcp', handleSessionRequest); + +app.listen(PORT, error => { + if (error) { + console.error('Failed to start server:', error); + // eslint-disable-next-line unicorn/no-process-exit + process.exit(1); + } + console.log(`Custom-method example server listening on http://localhost:${PORT}/mcp`); + console.log('Custom methods: acme/search, acme/analytics'); +}); + +process.on('SIGINT', async () => { + for (const sid in transports) await transports[sid]!.close(); + process.exit(0); +}); diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 57eab6932..1fa079992 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -40,6 +40,8 @@ import { isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResultResponse, + isNotificationMethod, + isRequestMethod, ProtocolError, ProtocolErrorCode, SUPPORTED_PROTOCOL_VERSIONS @@ -809,7 +811,7 @@ export abstract class Protocol { }; if (!this._transport) { - earlyReject(new Error('Not connected')); + earlyReject(new SdkError(SdkErrorCode.NotConnected, 'Not connected')); return; } @@ -1057,6 +1059,172 @@ export abstract class Protocol { removeNotificationHandler(method: NotificationMethod): void { this._notificationHandlers.delete(method); } + + /** + * Registers a handler for a custom (non-standard) request method. + * + * Unlike {@linkcode setRequestHandler}, this accepts any method + * string and validates incoming params against a user-provided schema instead of an SDK-defined + * one. Capability checks are skipped. The handler receives the same {@linkcode BaseContext | context} + * as standard handlers, including cancellation, task support, and bidirectional send/notify. + */ + setCustomRequestHandler

( + method: string, + paramsSchema: P, + handler: (params: SchemaOutput

, ctx: ContextT) => Result | Promise + ): void { + if (isRequestMethod(method)) { + throw new Error(`"${method}" is a standard MCP request method. Use setRequestHandler() instead.`); + } + this._requestHandlers.set(method, (request, ctx) => { + const { _meta, ...userParams } = (request.params ?? {}) as Record; + void _meta; + const parsed = parseSchema(paramsSchema, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`); + } + return Promise.resolve(handler(parsed.data, ctx)); + }); + } + + /** + * Removes a custom request handler previously registered with + * {@linkcode Protocol.setCustomRequestHandler | setCustomRequestHandler}. + */ + removeCustomRequestHandler(method: string): void { + if (isRequestMethod(method)) { + throw new Error(`"${method}" is a standard MCP request method. Use removeRequestHandler() instead.`); + } + this._requestHandlers.delete(method); + } + + /** + * Registers a handler for a custom (non-standard) notification method. + * + * Unlike {@linkcode Protocol.setNotificationHandler | setNotificationHandler}, this accepts any + * method string and validates incoming params against a user-provided schema instead of an + * SDK-defined one. + */ + setCustomNotificationHandler

( + method: string, + paramsSchema: P, + handler: (params: SchemaOutput

) => void | Promise + ): void { + if (isNotificationMethod(method)) { + throw new Error(`"${method}" is a standard MCP notification method. Use setNotificationHandler() instead.`); + } + this._notificationHandlers.set(method, notification => { + const { _meta, ...userParams } = (notification.params ?? {}) as Record; + void _meta; + const parsed = parseSchema(paramsSchema, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`); + } + return Promise.resolve(handler(parsed.data)); + }); + } + + /** + * Removes a custom notification handler previously registered with + * {@linkcode Protocol.setCustomNotificationHandler | setCustomNotificationHandler}. + */ + removeCustomNotificationHandler(method: string): void { + if (isNotificationMethod(method)) { + throw new Error(`"${method}" is a standard MCP notification method. Use removeNotificationHandler() instead.`); + } + this._notificationHandlers.delete(method); + } + + /** + * Sends a custom (non-standard) request and waits for a response, validating the result against + * the provided schema. + * + * Unlike {@linkcode Protocol.request | request}, this accepts any method string. Capability + * checks do not apply to custom methods regardless of + * {@linkcode ProtocolOptions.enforceStrictCapabilities}, since + * `assertCapabilityForMethod` only covers + * standard MCP methods. + * + * Pass a `{ params, result }` schema bundle as the third argument to get typed `params` and + * pre-send validation; pass a bare result schema for loose, unvalidated params. + * + * The `params` schema is used only for validation — the value you pass is sent as-is. + * Transforms (e.g. `.trim()`) and defaults (e.g. `.default(n)`) on the schema are not + * applied to outbound data, matching the behavior of {@linkcode Protocol.request | request}. + */ + sendCustomRequest

( + method: string, + params: SchemaOutput

, + schemas: { params: P; result: R }, + options?: RequestOptions + ): Promise>; + sendCustomRequest( + method: string, + params: Record | undefined, + resultSchema: R, + options?: RequestOptions + ): Promise>; + async sendCustomRequest( + method: string, + params: Record | undefined, + schemaOrBundle: AnySchema | { params: AnySchema; result: AnySchema }, + options?: RequestOptions + ): Promise { + let resultSchema: AnySchema; + if (isSchemaBundle(schemaOrBundle)) { + const parsed = parseSchema(schemaOrBundle.params, params); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`); + } + resultSchema = schemaOrBundle.result; + } else { + resultSchema = schemaOrBundle; + } + return this._requestWithSchema({ method, params } as Request, resultSchema, options); + } + + /** + * Sends a custom (non-standard) notification. + * + * Unlike {@linkcode Protocol.notification | notification}, this accepts any method string. It + * routes through {@linkcode Protocol.notification | notification}, so debouncing and task-queued + * delivery apply. Capability checks are a no-op for custom methods since + * `assertNotificationCapability` only covers + * standard MCP notifications. + * + * Pass a `{ params }` schema bundle as the third argument to get typed `params` and pre-send + * validation. The schema validates only — transforms and defaults are not applied to + * outbound data; the value you pass is sent as-is. + */ + sendCustomNotification

( + method: string, + params: SchemaOutput

, + schemas: { params: P }, + options?: NotificationOptions + ): Promise; + sendCustomNotification(method: string, params?: Record, options?: NotificationOptions): Promise; + async sendCustomNotification( + method: string, + params?: Record, + schemasOrOptions?: { params: AnySchema } | NotificationOptions, + maybeOptions?: NotificationOptions + ): Promise { + let options: NotificationOptions | undefined; + if (schemasOrOptions && 'params' in schemasOrOptions) { + const parsed = parseSchema(schemasOrOptions.params, params); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`); + } + options = maybeOptions; + } else { + options = schemasOrOptions; + } + return this.notification({ method, params } as Notification, options); + } +} + +function isSchemaBundle(value: AnySchema | { params: AnySchema; result: AnySchema }): value is { params: AnySchema; result: AnySchema } { + return !('~standard' in value) && 'params' in value && 'result' in value; } function isPlainObject(value: unknown): value is Record { diff --git a/packages/core/src/types/schemas.ts b/packages/core/src/types/schemas.ts index 86acf11d7..4743f4f25 100644 --- a/packages/core/src/types/schemas.ts +++ b/packages/core/src/types/schemas.ts @@ -2209,6 +2209,20 @@ const notificationSchemas = buildSchemaMap([...ClientNotificationSchema.options, NotificationSchemaType >; +/** + * Type predicate: returns true if `method` is a standard MCP request method. + */ +export function isRequestMethod(method: string): method is RequestMethod { + return Object.hasOwn(requestSchemas, method); +} + +/** + * Type predicate: returns true if `method` is a standard MCP notification method. + */ +export function isNotificationMethod(method: string): method is NotificationMethod { + return Object.hasOwn(notificationSchemas, method); +} + /** * Gets the Zod schema for a given request method. * The return type is a ZodType that parses to RequestTypeMap[M], allowing callers diff --git a/packages/core/test/shared/customMethods.test.ts b/packages/core/test/shared/customMethods.test.ts new file mode 100644 index 000000000..7cf076a98 --- /dev/null +++ b/packages/core/test/shared/customMethods.test.ts @@ -0,0 +1,275 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import * as z from 'zod/v4'; + +import { SdkError, SdkErrorCode } from '../../src/errors/sdkErrors.js'; +import type { BaseContext } from '../../src/shared/protocol.js'; +import { Protocol } from '../../src/shared/protocol.js'; +import { ProtocolError, ProtocolErrorCode } from '../../src/types/index.js'; +import { InMemoryTransport } from '../../src/util/inMemory.js'; + +class TestProtocol extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + protected buildContext(ctx: BaseContext): BaseContext { + return ctx; + } +} + +async function linkedPair(): Promise<[TestProtocol, TestProtocol]> { + const a = new TestProtocol(); + const b = new TestProtocol(); + const [ta, tb] = InMemoryTransport.createLinkedPair(); + await Promise.all([a.connect(ta), b.connect(tb)]); + return [a, b]; +} + +const SearchParams = z.object({ query: z.string(), limit: z.number().optional() }); +const SearchResult = z.object({ hits: z.array(z.string()), total: z.number() }); +const StatusParams = z.object({ status: z.enum(['idle', 'busy']) }); + +describe('custom request handlers', () => { + let client: TestProtocol; + let server: TestProtocol; + + beforeEach(async () => { + [client, server] = await linkedPair(); + }); + + test('happy path: typed params and result', async () => { + server.setCustomRequestHandler('acme/search', SearchParams, params => { + return { hits: [`result:${params.query}`], total: 1 }; + }); + + const result = await client.sendCustomRequest('acme/search', { query: 'widgets', limit: 5 }, SearchResult); + expect(result.hits).toEqual(['result:widgets']); + expect(result.total).toBe(1); + }); + + test('handler receives full context (signal, mcpReq id)', async () => { + let received: BaseContext | undefined; + server.setCustomRequestHandler('acme/ctx', z.object({}), (_params, ctx) => { + received = ctx; + return {}; + }); + + await client.sendCustomRequest('acme/ctx', {}, z.object({})); + expect(received).toBeDefined(); + expect(received?.mcpReq.signal).toBeInstanceOf(AbortSignal); + expect(received?.mcpReq.id).toBeDefined(); + expect(received?.mcpReq.method).toBe('acme/ctx'); + }); + + test('strict schema: SDK-injected _meta is stripped before validation', async () => { + let receivedQ: string | undefined; + let receivedMeta: unknown; + server.setCustomRequestHandler('acme/strict', z.object({ q: z.string() }).strict(), (params, ctx) => { + receivedQ = params.q; + receivedMeta = ctx.mcpReq._meta; + return {}; + }); + await expect(client.sendCustomRequest('acme/strict', { q: 'hi' }, z.object({}), { onprogress: () => {} })).resolves.toEqual({}); + expect(receivedQ).toBe('hi'); + expect(receivedMeta).toMatchObject({ progressToken: expect.anything() }); + }); + + test('invalid params -> InvalidParams ProtocolError', async () => { + server.setCustomRequestHandler('acme/search', SearchParams, () => ({ hits: [], total: 0 })); + + await expect(client.sendCustomRequest('acme/search', { query: 123 }, SearchResult)).rejects.toSatisfy( + (e: unknown) => e instanceof ProtocolError && e.code === ProtocolErrorCode.InvalidParams + ); + }); + + test('collision guard: throws on standard request method', () => { + expect(() => server.setCustomRequestHandler('ping', z.object({}), () => ({}))).toThrow(/standard MCP request method/); + expect(() => server.setCustomRequestHandler('tools/call', z.object({}), () => ({}))).toThrow(/standard MCP request method/); + expect(() => server.removeCustomRequestHandler('tools/list')).toThrow(/standard MCP request method/); + }); + + test('collision guard: does NOT trigger on Object.prototype keys', () => { + for (const m of ['toString', 'constructor', 'hasOwnProperty', '__proto__']) { + expect(() => server.setCustomRequestHandler(m, z.object({}), () => ({}))).not.toThrow(); + expect(() => server.setCustomNotificationHandler(m, z.object({}), () => {})).not.toThrow(); + } + }); + + test('removeCustomRequestHandler -> subsequent request fails MethodNotFound', async () => { + server.setCustomRequestHandler('acme/search', SearchParams, () => ({ hits: [], total: 0 })); + await client.sendCustomRequest('acme/search', { query: 'x' }, SearchResult); + + server.removeCustomRequestHandler('acme/search'); + await expect(client.sendCustomRequest('acme/search', { query: 'x' }, SearchResult)).rejects.toSatisfy( + (e: unknown) => e instanceof ProtocolError && e.code === ProtocolErrorCode.MethodNotFound + ); + }); + + test('double-register -> last wins', async () => { + server.setCustomRequestHandler('acme/v', z.object({}), () => ({ v: 1 })); + server.setCustomRequestHandler('acme/v', z.object({}), () => ({ v: 2 })); + const result = await client.sendCustomRequest('acme/v', {}, z.object({ v: z.number() })); + expect(result.v).toBe(2); + }); +}); + +describe('custom notification handlers', () => { + let client: TestProtocol; + let server: TestProtocol; + + beforeEach(async () => { + [client, server] = await linkedPair(); + }); + + test('handler invoked with typed params', async () => { + const received: string[] = []; + client.setCustomNotificationHandler('acme/status', StatusParams, params => { + received.push(params.status); + }); + + await server.sendCustomNotification('acme/status', { status: 'busy' }); + await server.sendCustomNotification('acme/status', { status: 'idle' }); + await vi.waitFor(() => expect(received).toEqual(['busy', 'idle'])); + }); + + test('collision guard: throws on standard notification method', () => { + expect(() => client.setCustomNotificationHandler('notifications/cancelled', z.object({}), () => {})).toThrow( + /standard MCP notification method/ + ); + expect(() => client.setCustomNotificationHandler('notifications/progress', z.object({}), () => {})).toThrow( + /standard MCP notification method/ + ); + expect(() => client.removeCustomNotificationHandler('notifications/initialized')).toThrow(/standard MCP notification method/); + }); + + test('removeCustomNotificationHandler -> subsequent notifications not delivered', async () => { + const handler = vi.fn(); + client.setCustomNotificationHandler('acme/status', StatusParams, handler); + await server.sendCustomNotification('acme/status', { status: 'busy' }); + await vi.waitFor(() => expect(handler).toHaveBeenCalledTimes(1)); + + client.removeCustomNotificationHandler('acme/status'); + await server.sendCustomNotification('acme/status', { status: 'idle' }); + // Give the event loop a tick; handler should not be called again. + await new Promise(r => setTimeout(r, 10)); + expect(handler).toHaveBeenCalledTimes(1); + }); + + test('invalid params -> handler not invoked, error surfaced via onerror', async () => { + const handler = vi.fn(); + const errors: Error[] = []; + client.setCustomNotificationHandler('acme/status', StatusParams, handler); + client.onerror = e => errors.push(e); + + await server.sendCustomNotification('acme/status', { status: 'unknown' }); + await vi.waitFor(() => expect(errors.length).toBeGreaterThan(0)); + expect(handler).not.toHaveBeenCalled(); + }); +}); + +describe('sendCustomRequest', () => { + test('not connected -> throws SdkError NotConnected', async () => { + const proto = new TestProtocol(); + await expect(proto.sendCustomRequest('acme/x', {}, z.object({}))).rejects.toSatisfy( + (e: unknown) => e instanceof SdkError && e.code === SdkErrorCode.NotConnected + ); + }); + + test('undefined params accepted', async () => { + const [client, server] = await linkedPair(); + server.setCustomRequestHandler('acme/noargs', z.undefined().or(z.object({})), () => ({ ok: true })); + const result = await client.sendCustomRequest('acme/noargs', undefined, z.object({ ok: z.boolean() })); + expect(result.ok).toBe(true); + }); + + test('result validated against resultSchema', async () => { + const [client, server] = await linkedPair(); + server.setCustomRequestHandler('acme/badresult', z.object({}), () => ({ hits: 'not-an-array', total: 0 })); + await expect(client.sendCustomRequest('acme/badresult', {}, SearchResult)).rejects.toThrow(); + }); + + test('schema bundle overload: typed params and result', async () => { + const [client, server] = await linkedPair(); + server.setCustomRequestHandler('acme/search', SearchParams, p => ({ hits: [p.query], total: 1 })); + const result = await client.sendCustomRequest('acme/search', { query: 'q' }, { params: SearchParams, result: SearchResult }); + expect(result.hits).toEqual(['q']); + }); + + test('schema bundle overload: invalid params rejects InvalidParams before transport', async () => { + const proto = new TestProtocol(); // not connected + // InvalidParams (pre-send validation) — proves it does NOT reach the NotConnected path + await expect( + proto.sendCustomRequest('acme/search', { query: 123 } as unknown as z.output, { + params: SearchParams, + result: SearchResult + }) + ).rejects.toSatisfy((e: unknown) => e instanceof ProtocolError && e.code === ProtocolErrorCode.InvalidParams); + }); + + test('schema bundle overload: params sent as-is (validate-only, no outbound transforms)', async () => { + const [client, server] = await linkedPair(); + const P = z.object({ query: z.string().transform(s => s.trim()), page: z.number() }); + let received: unknown; + server.setCustomRequestHandler('acme/q', z.unknown(), p => { + received = p; + return {}; + }); + await client.sendCustomRequest('acme/q', { query: ' hi ', page: 1 }, { params: P, result: z.object({}) }); + expect(received).toEqual({ query: ' hi ', page: 1 }); + }); +}); + +describe('sendCustomNotification', () => { + test('not connected -> throws SdkError NotConnected', async () => { + const proto = new TestProtocol(); + await expect(proto.sendCustomNotification('acme/x', {})).rejects.toSatisfy( + (e: unknown) => e instanceof SdkError && e.code === SdkErrorCode.NotConnected + ); + }); + + test('delivered to peer with no handler -> no error thrown on sender', async () => { + const [client, server] = await linkedPair(); + const errors: Error[] = []; + client.onerror = e => errors.push(e); + await expect(server.sendCustomNotification('acme/unhandled', { x: 1 })).resolves.toBeUndefined(); + }); + + test('schema bundle overload: invalid params throws InvalidParams before transport', async () => { + const proto = new TestProtocol(); // not connected + await expect( + proto.sendCustomNotification('acme/status', { status: 'bad' } as unknown as z.output, { + params: StatusParams + }) + ).rejects.toSatisfy((e: unknown) => e instanceof ProtocolError && e.code === ProtocolErrorCode.InvalidParams); + }); + + test('schema bundle overload: valid params delivered, options as 4th arg', async () => { + const [client, server] = await linkedPair(); + const received: string[] = []; + client.setCustomNotificationHandler('acme/status', StatusParams, p => { + received.push(p.status); + }); + await server.sendCustomNotification('acme/status', { status: 'busy' }, { params: StatusParams }, {}); + await vi.waitFor(() => expect(received).toEqual(['busy'])); + }); + + test('routes through notification(): debouncing applies to custom methods', async () => { + const a = new TestProtocol({ debouncedNotificationMethods: ['acme/tick'] }); + const b = new TestProtocol(); + const [ta, tb] = InMemoryTransport.createLinkedPair(); + await Promise.all([a.connect(ta), b.connect(tb)]); + + let count = 0; + b.setCustomNotificationHandler('acme/tick', z.undefined().or(z.object({})), () => { + count++; + }); + + // Three synchronous sends should coalesce to one delivery. + void a.sendCustomNotification('acme/tick'); + void a.sendCustomNotification('acme/tick'); + void a.sendCustomNotification('acme/tick'); + await new Promise(r => setTimeout(r, 10)); + expect(count).toBe(1); + }); +});