From 16ba027ec86f3723705820c039f1b2d4d48df37c Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 12:35:02 -0700 Subject: [PATCH 01/84] Add types for tasks --- src/types.ts | 144 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 135 insertions(+), 9 deletions(-) diff --git a/src/types.ts b/src/types.ts index e6d3fe46e..e3aff3831 100644 --- a/src/types.ts +++ b/src/types.ts @@ -5,6 +5,9 @@ export const LATEST_PROTOCOL_VERSION = '2025-06-18'; export const DEFAULT_NEGOTIATED_PROTOCOL_VERSION = '2025-03-26'; export const SUPPORTED_PROTOCOL_VERSIONS = [LATEST_PROTOCOL_VERSION, '2025-03-26', '2024-11-05', '2024-10-07']; +export const TASK_META_KEY = 'modelcontextprotocol.io/task'; +export const RELATED_TASK_META_KEY = 'modelcontextprotocol.io/related-task'; + /* JSON-RPC types */ export const JSONRPC_VERSION = '2.0'; @@ -18,12 +21,46 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]); */ export const CursorSchema = z.string(); +/** + * Task creation metadata, used to ask that the server create a task to represent a request. + */ +export const TaskRequestMetadataSchema = z + .object({ + /** + * The task ID to use as a reference to the created task. + */ + taskId: z.string(), + + /** + * Time in milliseconds to ask to keep task results available after completion. Only used with taskId. + */ + keepAlive: z.number().optional() + }) + .passthrough(); + +/** + * Task association metadata, used to signal which task a message originated from. + */ +export const RelatedTaskMetadataSchema = z + .object({ + taskId: z.string() + }) + .passthrough(); + const RequestMetaSchema = z .object({ /** * If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. */ - progressToken: z.optional(ProgressTokenSchema) + progressToken: z.optional(ProgressTokenSchema), + /** + * If specified, the caller is requesting that the receiver create a task to represent the request. + */ + [TASK_META_KEY]: z.optional(TaskRequestMetadataSchema), + /** + * If specified, this request is related to the provided task. + */ + [RELATED_TASK_META_KEY]: z.optional(RelatedTaskMetadataSchema) }) .passthrough(); @@ -44,7 +81,16 @@ const BaseNotificationParamsSchema = z * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. */ - _meta: z.optional(z.object({}).passthrough()) + _meta: z.optional( + z + .object({ + /** + * If specified, this request is related to the provided task. + */ + [RELATED_TASK_META_KEY]: z.optional(RelatedTaskMetadataSchema) + }) + .passthrough() + ) }) .passthrough(); @@ -59,7 +105,16 @@ export const ResultSchema = z * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. */ - _meta: z.optional(z.object({}).passthrough()) + _meta: z.optional( + z + .object({ + /** + * If specified, this request is related to the provided task. + */ + [RELATED_TASK_META_KEY]: z.optional(RelatedTaskMetadataSchema) + }) + .passthrough() + ) }) .passthrough(); @@ -440,6 +495,51 @@ export const ProgressNotificationSchema = NotificationSchema.extend({ }) }); +/* Tasks */ +/** + * A pollable state object associated with a request. + */ +export const TaskSchema = z.object({ + taskId: z.string(), + status: z.enum(['submitted', 'working', 'completed', 'failed', 'cancelled', 'unknown']), + keepAlive: z.union([z.number(), z.null()]), + pollFrequency: z.optional(z.number()), + error: z.optional(z.string()) +}); + +/** + * An out-of-band notification used to inform the receiver of a task being created. + */ +export const TaskCreatedNotificationSchema = NotificationSchema.extend({ + method: z.literal('notifications/tasks/created'), + params: BaseNotificationParamsSchema +}); + +/** + * A request to get the state of a specific task. + */ +export const GetTaskRequestSchema = RequestSchema.extend({ + method: z.literal('tasks/get'), + params: BaseRequestParamsSchema.extend({ + taskId: z.string() + }) +}); + +/** + * The response to a tasks/get request. + */ +export const GetTaskResultSchema = ResultSchema.merge(TaskSchema); + +/** + * A request to get the result of a specific task. + */ +export const GetTaskPayloadRequestSchema = RequestSchema.extend({ + method: z.literal('tasks/result'), + params: BaseRequestParamsSchema.extend({ + taskId: z.string() + }) +}); + /* Pagination */ export const PaginatedRequestSchema = RequestSchema.extend({ params: BaseRequestParamsSchema.extend({ @@ -1416,20 +1516,36 @@ export const ClientRequestSchema = z.union([ SubscribeRequestSchema, UnsubscribeRequestSchema, CallToolRequestSchema, - ListToolsRequestSchema + ListToolsRequestSchema, + GetTaskRequestSchema, + GetTaskPayloadRequestSchema ]); export const ClientNotificationSchema = z.union([ CancelledNotificationSchema, ProgressNotificationSchema, InitializedNotificationSchema, - RootsListChangedNotificationSchema + RootsListChangedNotificationSchema, + TaskCreatedNotificationSchema ]); -export const ClientResultSchema = z.union([EmptyResultSchema, CreateMessageResultSchema, ElicitResultSchema, ListRootsResultSchema]); +export const ClientResultSchema = z.union([ + EmptyResultSchema, + CreateMessageResultSchema, + ElicitResultSchema, + ListRootsResultSchema, + GetTaskResultSchema +]); /* Server messages */ -export const ServerRequestSchema = z.union([PingRequestSchema, CreateMessageRequestSchema, ElicitRequestSchema, ListRootsRequestSchema]); +export const ServerRequestSchema = z.union([ + PingRequestSchema, + CreateMessageRequestSchema, + ElicitRequestSchema, + ListRootsRequestSchema, + GetTaskRequestSchema, + GetTaskPayloadRequestSchema +]); export const ServerNotificationSchema = z.union([ CancelledNotificationSchema, @@ -1438,7 +1554,8 @@ export const ServerNotificationSchema = z.union([ ResourceUpdatedNotificationSchema, ResourceListChangedNotificationSchema, ToolListChangedNotificationSchema, - PromptListChangedNotificationSchema + PromptListChangedNotificationSchema, + TaskCreatedNotificationSchema ]); export const ServerResultSchema = z.union([ @@ -1451,7 +1568,8 @@ export const ServerResultSchema = z.union([ ListResourceTemplatesResultSchema, ReadResourceResultSchema, CallToolResultSchema, - ListToolsResultSchema + ListToolsResultSchema, + GetTaskResultSchema ]); export class McpError extends Error { @@ -1550,6 +1668,14 @@ export type PingRequest = Infer; export type Progress = Infer; export type ProgressNotification = Infer; +/* Tasks */ +export type Task = Infer; +export type TaskRequestMetadata = Infer; +export type TaskCreatedNotification = Infer; +export type GetTaskRequest = Infer; +export type GetTaskResult = Infer; +export type GetTaskPayloadRequest = Infer; + /* Pagination */ export type PaginatedRequest = Infer; export type PaginatedResult = Infer; From ecef231013498fdb903ce2d753fa5bdbcea57bbd Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 14:49:03 -0700 Subject: [PATCH 02/84] Implement PendingRequest and basic task API --- src/shared/protocol.ts | 143 +++++++++++++++++++++++++++++++++++++---- src/shared/request.ts | 63 ++++++++++++++++++ 2 files changed, 194 insertions(+), 12 deletions(-) create mode 100644 src/shared/request.ts diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 5cb969418..e447a371c 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -3,6 +3,9 @@ import { CancelledNotificationSchema, ClientCapabilities, ErrorCode, + GetTaskRequest, + GetTaskResultSchema, + GetTaskPayloadRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, @@ -17,16 +20,22 @@ import { Progress, ProgressNotification, ProgressNotificationSchema, + RELATED_TASK_META_KEY, Request, RequestId, Result, ServerCapabilities, RequestMeta, MessageExtraInfo, - RequestInfo + RequestInfo, + TaskCreatedNotificationSchema, + TASK_META_KEY, + GetTaskResult, + TaskRequestMetadata } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; +import { PendingRequest } from './request.js'; /** * Callback for progress notifications. @@ -93,6 +102,11 @@ export type RequestOptions = { * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; + + /** + * If provided, augments the request with task metadata to enable call-now, fetch-later execution patterns. + */ + task?: TaskRequestMetadata; } & TransportSendOptions; /** @@ -108,7 +122,11 @@ export type NotificationOptions = { /** * Extra data given to request handlers. */ -export type RequestHandlerExtra = { +export type RequestHandlerExtra< + SendRequestT extends Request, + SendNotificationT extends Notification, + SendResultT extends Result = Result +> = { /** * An abort signal used to communicate if the request was cancelled from the sender's side. */ @@ -152,7 +170,7 @@ export type RequestHandlerExtra>(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; + sendRequest: >(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; }; /** @@ -176,7 +194,7 @@ export abstract class Protocol) => Promise + (request: JSONRPCRequest, extra: RequestHandlerExtra) => Promise > = new Map(); private _requestHandlerAbortControllers: Map = new Map(); private _notificationHandlers: Map Promise> = new Map(); @@ -184,6 +202,8 @@ export abstract class Protocol = new Map(); private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); + private _pendingTaskCreations: Map void; reject: (reason: Error) => void }> = new Map(); + private _requestIdToTaskId: Map = new Map(); /** * Callback for when the connection is closed for any reason. @@ -202,7 +222,10 @@ export abstract class Protocol) => Promise; + fallbackRequestHandler?: ( + request: JSONRPCRequest, + extra: RequestHandlerExtra + ) => Promise; /** * A handler to invoke for any notification types that do not have their own handler installed. @@ -219,6 +242,17 @@ export abstract class Protocol { + const taskId = notification.params?._meta?.[RELATED_TASK_META_KEY]?.taskId; + if (taskId) { + const resolver = this._pendingTaskCreations.get(taskId); + if (resolver) { + resolver.resolve(); + this._pendingTaskCreations.delete(taskId); + } + } + }); + this.setRequestHandler( PingRequestSchema, // Automatic pong by default. @@ -310,10 +344,19 @@ export abstract class Protocol = { + const fullExtra: RequestHandlerExtra = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, _meta: request.params?._meta, @@ -444,6 +487,17 @@ export abstract class Protocol>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; + beginRequest>( + request: SendRequestT, + resultSchema: T, + options?: RequestOptions + ): PendingRequest { + const { relatedRequestId, resumptionToken, onresumptiontoken, task } = options ?? {}; + const { taskId, keepAlive } = task ?? {}; - return new Promise((resolve, reject) => { + const promise = new Promise>((resolve, reject) => { if (!this._transport) { reject(new Error('Not connected')); return; @@ -522,6 +581,21 @@ export abstract class Protocol { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); @@ -575,6 +649,48 @@ export abstract class Protocol>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { + return this.beginRequest(request, resultSchema, options).result(); + } + + /** + * Waits for a task creation notification with the given taskId. + * Returns a promise that resolves when the notifications/tasks/created notification is received, + * or rejects if the task is cleaned up (e.g., connection closed or request completed). + */ + waitForTaskCreation(taskId: string): Promise { + return new Promise((resolve, reject) => { + this._pendingTaskCreations.set(taskId, { resolve, reject }); + }); + } + + /** + * Gets the current status of a task. + */ + async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { + // @ts-expect-error SendRequestT cannot directly contain GetTaskRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); + } + + /** + * Retrieves the result of a completed task. + */ + async getTaskResult>( + params: GetTaskPayloadRequest['params'], + resultSchema: T, + options?: RequestOptions + ): Promise> { + // @ts-expect-error SendRequestT cannot directly contain GetTaskPayloadRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/result', params }, resultSchema, options); } /** @@ -644,7 +760,10 @@ export abstract class Protocol >( requestSchema: T, - handler: (request: z.infer, extra: RequestHandlerExtra) => SendResultT | Promise + handler: ( + request: z.infer, + extra: RequestHandlerExtra + ) => SendResultT | Promise ): void { const method = requestSchema.shape.method.value; this.assertRequestHandlerCapability(method); diff --git a/src/shared/request.ts b/src/shared/request.ts new file mode 100644 index 000000000..eda46ccd1 --- /dev/null +++ b/src/shared/request.ts @@ -0,0 +1,63 @@ +import { ZodType } from 'zod'; +import { Protocol } from './protocol.js'; +import { Request, Notification, Result, Task, GetTaskResult } from '../types.js'; + +const DEFAULT_POLLING_INTERNAL = 5000; + +export interface TaskHandlerOptions { + onTaskStatus: (task: Task) => Promise; +} + +export class PendingRequest { + constructor( + readonly protocol: Protocol, + readonly resultHandle: Promise, + readonly resultSchema: ZodType, + readonly taskId?: string + ) {} + + /** + * Waits for a result, calling onTaskStatus if provided and a task was created. + */ + async result(options?: Partial): Promise { + if (!options?.onTaskStatus || !this.taskId) { + // No task listener or task ID provided, just block for the result + return await this.resultHandle; + } + + // Whichever is successful first (or a failure if all fail) is returned. + return Promise.allSettled([ + this.resultHandle, + (async () => { + // Blocks for a notifications/tasks/created with the provided task ID + await this.protocol.waitForTaskCreation(this.taskId!); + return await this.taskHandler(options as TaskHandlerOptions); + })() + ]).then(([result, task]) => { + if (result.status === 'fulfilled') { + return result.value; + } else if (task.status === 'fulfilled') { + return task.value; + } + + const errors: unknown[] = [result.reason, task.reason]; + throw new Error(`Both request and task handler failed: ${errors.map(e => `${e}`).join(', ')}`); + }); + } + + /** + * Encapsulates polling for a result, calling onTaskStatus after querying the task. + */ + private async taskHandler({ onTaskStatus }: TaskHandlerOptions): Promise { + // Poll for completion + let task: GetTaskResult; + do { + task = await this.protocol.getTask({ taskId: this.taskId! }); + await onTaskStatus(task); + await new Promise(resolve => setTimeout(resolve, task.pollFrequency ?? DEFAULT_POLLING_INTERNAL)); + } while (!(['complete', 'failed', 'cancelled', 'unknown'] as (typeof task.status)[]).includes(task.status)); + + // Process result + return await this.protocol.getTaskResult({ taskId: this.taskId! }, this.resultSchema); + } +} From 41f212486061259b7ce48f0f9d07aec958554a95 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 14:54:12 -0700 Subject: [PATCH 03/84] Implement RelatedTask metadata sends --- src/shared/protocol.ts | 64 ++++++++++++++++++++++++++++++++++++++---- src/types.ts | 1 + 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index e447a371c..3f528e454 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -31,7 +31,8 @@ import { TaskCreatedNotificationSchema, TASK_META_KEY, GetTaskResult, - TaskRequestMetadata + TaskRequestMetadata, + RelatedTaskMetadata } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; @@ -107,6 +108,11 @@ export type RequestOptions = { * If provided, augments the request with task metadata to enable call-now, fetch-later execution patterns. */ task?: TaskRequestMetadata; + + /** + * If provided, associates this request with a related task. + */ + relatedTask?: RelatedTaskMetadata; } & TransportSendOptions; /** @@ -117,6 +123,11 @@ export type NotificationOptions = { * May be used to indicate to the transport which incoming request to associate this outgoing notification with. */ relatedRequestId?: RequestId; + + /** + * If provided, associates this notification with a related task. + */ + relatedTask?: RelatedTaskMetadata; }; /** @@ -548,7 +559,7 @@ export abstract class Protocol { - const { relatedRequestId, resumptionToken, onresumptiontoken, task } = options ?? {}; + const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; const { taskId, keepAlive } = task ?? {}; const promise = new Promise>((resolve, reject) => { @@ -596,6 +607,17 @@ export abstract class Protocol { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); @@ -705,8 +727,9 @@ export abstract class Protocol this._onerror(error)); @@ -741,11 +779,25 @@ export abstract class Protocol; /* Tasks */ export type Task = Infer; export type TaskRequestMetadata = Infer; +export type RelatedTaskMetadata = Infer; export type TaskCreatedNotification = Infer; export type GetTaskRequest = Infer; export type GetTaskResult = Infer; From a8fabb61e194d1d1f1d1184e05b3da1167220efe Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 15:45:11 -0700 Subject: [PATCH 04/84] Implement task state management --- src/shared/protocol.ts | 122 +++++++++++++++++++++++++++++++++++++++-- src/shared/request.ts | 8 +-- src/shared/task.ts | 51 +++++++++++++++++ src/types.ts | 6 +- 4 files changed, 174 insertions(+), 13 deletions(-) create mode 100644 src/shared/task.ts diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 3f528e454..292d8c902 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -4,8 +4,10 @@ import { ClientCapabilities, ErrorCode, GetTaskRequest, + GetTaskRequestSchema, GetTaskResultSchema, GetTaskPayloadRequest, + GetTaskPayloadRequestSchema, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, @@ -31,12 +33,13 @@ import { TaskCreatedNotificationSchema, TASK_META_KEY, GetTaskResult, - TaskRequestMetadata, + TaskMetadata, RelatedTaskMetadata } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; import { PendingRequest } from './request.js'; +import { TaskStore } from './task.js'; /** * Callback for progress notifications. @@ -62,6 +65,11 @@ export type ProtocolOptions = { * e.g., ['notifications/tools/list_changed'] */ debouncedNotificationMethods?: string[]; + /** + * Optional task storage implementation. If provided, the implementation will automatically + * handle task creation, status tracking, and result storage. + */ + taskStore?: TaskStore; }; /** @@ -107,7 +115,7 @@ export type RequestOptions = { /** * If provided, augments the request with task metadata to enable call-now, fetch-later execution patterns. */ - task?: TaskRequestMetadata; + task?: TaskMetadata; /** * If provided, associates this request with a related task. @@ -215,6 +223,7 @@ export abstract class Protocol(); private _pendingTaskCreations: Map void; reject: (reason: Error) => void }> = new Map(); private _requestIdToTaskId: Map = new Map(); + private _taskStore?: TaskStore; /** * Callback for when the connection is closed for any reason. @@ -245,8 +254,7 @@ export abstract class Protocol { - const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); - controller?.abort(notification.params.reason); + this._oncancel(notification); }); this.setNotificationHandler(ProgressNotificationSchema, notification => { @@ -269,6 +277,65 @@ export abstract class Protocol ({}) as SendResultT ); + + // Install task handlers if TaskStore is provided + this._taskStore = _options?.taskStore; + if (this._taskStore) { + this.setRequestHandler(GetTaskRequestSchema, async request => { + const task = await this._taskStore!.getTask(request.params.taskId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + + // @ts-expect-error SendResultT cannot contain GetTaskResult, but we include it in our derived types everywhere else + return { + ...task, + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: request.params.taskId + } + } + } as SendResultT; + }); + + this.setRequestHandler(GetTaskPayloadRequestSchema, async request => { + const task = await this._taskStore!.getTask(request.params.taskId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + + if (task.status !== 'completed') { + throw new McpError(ErrorCode.InvalidParams, `Cannot retrieve result: Task status is '${task.status}', not 'completed'`); + } + + const result = await this._taskStore!.getTaskResult(request.params.taskId); + return { + ...result, + _meta: { + ...result._meta, + [RELATED_TASK_META_KEY]: { + taskId: request.params.taskId + } + } + } as SendResultT; + }); + } + } + + private async _oncancel(notification: z.infer): Promise { + // Handle request cancellation + const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); + controller?.abort(notification.params.reason); + + // If this request had a task, mark it as cancelled in storage + const taskId = this._requestIdToTaskId.get(Number(notification.params.requestId)); + if (taskId && this._taskStore) { + try { + await this._taskStore.updateTaskStatus(taskId, 'cancelled'); + } catch (error) { + this._onerror(new Error(`Failed to cancel task ${taskId}: ${error}`)); + } + } } private _setupTimeout( @@ -429,16 +496,59 @@ export abstract class Protocol handler(request, fullExtra)) .then( - result => { + async result => { if (abortController.signal.aborted) { return; } - return capturedTransport?.send({ + // If this request asked for task creation, create the task and send notification + const taskMetadata = request.params?._meta?.[TASK_META_KEY]; + if (taskMetadata && this._taskStore) { + const task = await this._taskStore!.getTask(taskMetadata.taskId); + if (task) { + throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); + } + + try { + await this._taskStore.createTask(taskMetadata, request.id, { + method: request.method, + params: request.params + }); + + // Send task created notification + await this.notification( + { + method: 'notifications/tasks/created', + params: { + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: taskMetadata.taskId + } + } + } + } as SendNotificationT, + { relatedRequestId: request.id } + ); + } catch (error) { + this._onerror(new Error(`Failed to create task: ${error}`)); + } + } + + // Send the response + await capturedTransport?.send({ result, jsonrpc: '2.0', id: request.id }); + + // Store the result if this was a task-based request + if (taskMetadata && this._taskStore) { + try { + await this._taskStore.storeTaskResult(taskMetadata.taskId, result); + } catch (error) { + this._onerror(new Error(`Failed to store task result: ${error}`)); + } + } }, error => { if (abortController.signal.aborted) { diff --git a/src/shared/request.ts b/src/shared/request.ts index eda46ccd1..26186f2d9 100644 --- a/src/shared/request.ts +++ b/src/shared/request.ts @@ -31,7 +31,7 @@ export class PendingRequest { // Blocks for a notifications/tasks/created with the provided task ID await this.protocol.waitForTaskCreation(this.taskId!); - return await this.taskHandler(options as TaskHandlerOptions); + return await this.taskHandler(this.taskId!, options as TaskHandlerOptions); })() ]).then(([result, task]) => { if (result.status === 'fulfilled') { @@ -48,16 +48,16 @@ export class PendingRequest { + private async taskHandler(taskId: string, { onTaskStatus }: TaskHandlerOptions): Promise { // Poll for completion let task: GetTaskResult; do { - task = await this.protocol.getTask({ taskId: this.taskId! }); + task = await this.protocol.getTask({ taskId: taskId }); await onTaskStatus(task); await new Promise(resolve => setTimeout(resolve, task.pollFrequency ?? DEFAULT_POLLING_INTERNAL)); } while (!(['complete', 'failed', 'cancelled', 'unknown'] as (typeof task.status)[]).includes(task.status)); // Process result - return await this.protocol.getTaskResult({ taskId: this.taskId! }, this.resultSchema); + return await this.protocol.getTaskResult({ taskId: taskId }, this.resultSchema); } } diff --git a/src/shared/task.ts b/src/shared/task.ts new file mode 100644 index 000000000..0a4b52560 --- /dev/null +++ b/src/shared/task.ts @@ -0,0 +1,51 @@ +import { Task, TaskMetadata, Request, RequestId, Result } from '../types.js'; + +/** + * Interface for storing and retrieving task state and results. + * + * Similar to Transport, this allows pluggable task storage implementations + * (in-memory, database, distributed cache, etc.). + */ +export interface TaskStore { + /** + * Creates a new task with the given metadata and original request. + * + * @param task - The task creation metadata from the request + * @param requestId - The JSON-RPC request ID + * @param request - The original request that triggered task creation + */ + createTask(task: TaskMetadata, requestId: RequestId, request: Request): Promise; + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @returns The task state including status, keepAlive, pollFrequency, and optional error + */ + getTask(taskId: string): Promise; + + /** + * Stores the result of a completed task. + * + * @param taskId - The task identifier + * @param result - The result to store + */ + storeTaskResult(taskId: string, result: Result): Promise; + + /** + * Retrieves the stored result of a task. + * + * @param taskId - The task identifier + * @returns The stored result + */ + getTaskResult(taskId: string): Promise; + + /** + * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). + * + * @param taskId - The task identifier + * @param status - The new status + * @param error - Optional error message if status is 'failed' or 'cancelled' + */ + updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise; +} diff --git a/src/types.ts b/src/types.ts index 419310f63..23caa5f41 100644 --- a/src/types.ts +++ b/src/types.ts @@ -24,7 +24,7 @@ export const CursorSchema = z.string(); /** * Task creation metadata, used to ask that the server create a task to represent a request. */ -export const TaskRequestMetadataSchema = z +export const TaskMetadataSchema = z .object({ /** * The task ID to use as a reference to the created task. @@ -56,7 +56,7 @@ const RequestMetaSchema = z /** * If specified, the caller is requesting that the receiver create a task to represent the request. */ - [TASK_META_KEY]: z.optional(TaskRequestMetadataSchema), + [TASK_META_KEY]: z.optional(TaskMetadataSchema), /** * If specified, this request is related to the provided task. */ @@ -1670,7 +1670,7 @@ export type ProgressNotification = Infer; /* Tasks */ export type Task = Infer; -export type TaskRequestMetadata = Infer; +export type TaskMetadata = Infer; export type RelatedTaskMetadata = Infer; export type TaskCreatedNotification = Infer; export type GetTaskRequest = Infer; From b3420b3725d6d6db98d760162ddb1c9dd941b27f Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 15:59:52 -0700 Subject: [PATCH 05/84] Attach related task metadata to request handler --- src/shared/protocol.ts | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 292d8c902..8ee3d493b 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -481,12 +481,19 @@ export abstract class Protocol = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, _meta: request.params?._meta, - sendNotification: notification => this.notification(notification, { relatedRequestId: request.id }), - sendRequest: (r, resultSchema, options?) => this.request(r, resultSchema, { ...options, relatedRequestId: request.id }), + sendNotification: async notification => { + const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; + await this.notification(notification, { relatedRequestId: request.id, relatedTask }); + }, + sendRequest: async (r, resultSchema, options?) => { + const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; + return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id, relatedTask }); + }, authInfo: extra?.authInfo, requestId: request.id, requestInfo: extra?.requestInfo @@ -502,7 +509,6 @@ export abstract class Protocol Date: Thu, 23 Oct 2025 12:26:50 -0700 Subject: [PATCH 06/84] Create task before calling handler --- src/client/index.ts | 19 +++++++++ src/shared/protocol.ts | 90 ++++++++++++++++++++++++++---------------- src/shared/request.ts | 35 ++++++++++------ src/shared/task.ts | 11 ++++++ 4 files changed, 108 insertions(+), 47 deletions(-) diff --git a/src/client/index.ts b/src/client/index.ts index 856eb18e5..b000088e6 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,5 +1,6 @@ import { mergeCapabilities, Protocol, ProtocolOptions, RequestOptions } from '../shared/protocol.js'; import { Transport } from '../shared/transport.js'; +import { PendingRequest } from '../shared/request.js'; import { CallToolRequest, CallToolResultSchema, @@ -326,6 +327,24 @@ export class Client< return this.request({ method: 'resources/unsubscribe', params }, EmptyResultSchema, options); } + /** + * Begins a tool call and returns a PendingRequest for granular control over task-based execution. + * + * This is useful when you want to create a task for a long-running tool call and poll for results later. + */ + beginCallTool( + params: CallToolRequest['params'], + resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, + options?: RequestOptions + ): PendingRequest { + return this.beginRequest({ method: 'tools/call', params }, resultSchema, options); + } + + /** + * Calls a tool and waits for the result. Automatically validates structured output if the tool has an outputSchema. + * + * For task-based execution with granular control, use beginCallTool() instead. + */ async callTool( params: CallToolRequest['params'], resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 8ee3d493b..6bbdbf84d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -501,44 +501,59 @@ export abstract class Protocol handler(request, fullExtra)) - .then( - async result => { - if (abortController.signal.aborted) { - return; + .then(async () => { + // If this request asked for task creation, create the task and send notification + if (taskMetadata && this._taskStore) { + const task = await this._taskStore!.getTask(taskMetadata.taskId); + if (task) { + throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); } - // If this request asked for task creation, create the task and send notification - if (taskMetadata && this._taskStore) { - const task = await this._taskStore!.getTask(taskMetadata.taskId); - if (task) { - throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); - } - - try { - await this._taskStore.createTask(taskMetadata, request.id, { - method: request.method, - params: request.params - }); - - // Send task created notification - await this.notification( - { - method: 'notifications/tasks/created', - params: { - _meta: { - [RELATED_TASK_META_KEY]: { - taskId: taskMetadata.taskId - } + try { + await this._taskStore.createTask(taskMetadata, request.id, { + method: request.method, + params: request.params + }); + + // Send task created notification + await this.notification( + { + method: 'notifications/tasks/created', + params: { + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: taskMetadata.taskId } } - } as SendNotificationT, - { relatedRequestId: request.id } - ); + } + } as SendNotificationT, + { relatedRequestId: request.id } + ); + } catch (error) { + throw new McpError(ErrorCode.InternalError, `Failed to create task: ${taskMetadata.taskId}`); + } + } + }) + .then(async () => { + // If this request had a task, mark it as working + if (taskMetadata && this._taskStore) { + try { + await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'working'); + } catch (error) { + try { + await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'failed', 'Failed to mark task as working'); } catch (error) { - this._onerror(new Error(`Failed to create task: ${error}`)); + throw new McpError(ErrorCode.InternalError, `Failed to mark task as working: ${error}`); } } + } + }) + .then(() => handler(request, fullExtra)) + .then( + async result => { + if (abortController.signal.aborted) { + return; + } // Send the response await capturedTransport?.send({ @@ -552,7 +567,7 @@ export abstract class Protocol>((resolve, reject) => { + // For tasks, create an advance promise for the creation notification to avoid + // race conditions with installing this callback. + const taskCreated = taskId ? this.waitForTaskCreation(taskId) : Promise.resolve(); + + // Send the request + const result = new Promise>((resolve, reject) => { if (!this._transport) { reject(new Error('Not connected')); return; @@ -788,7 +808,7 @@ export abstract class Protocol { + private waitForTaskCreation(taskId: string): Promise { return new Promise((resolve, reject) => { this._pendingTaskCreations.set(taskId, { resolve, reject }); }); diff --git a/src/shared/request.ts b/src/shared/request.ts index 26186f2d9..fae4f1332 100644 --- a/src/shared/request.ts +++ b/src/shared/request.ts @@ -1,16 +1,21 @@ import { ZodType } from 'zod'; import { Protocol } from './protocol.js'; import { Request, Notification, Result, Task, GetTaskResult } from '../types.js'; +import { isTerminal } from './task.js'; const DEFAULT_POLLING_INTERNAL = 5000; +const DEFAULT_HANDLER = () => Promise.resolve(); + export interface TaskHandlerOptions { - onTaskStatus: (task: Task) => Promise; + onTaskCreated: () => Promise | void; + onTaskStatus: (task: Task) => Promise | void; } export class PendingRequest { constructor( readonly protocol: Protocol, + readonly taskCreatedHandle: Promise, readonly resultHandle: Promise, readonly resultSchema: ZodType, readonly taskId?: string @@ -20,24 +25,30 @@ export class PendingRequest): Promise { - if (!options?.onTaskStatus || !this.taskId) { - // No task listener or task ID provided, just block for the result + const { onTaskCreated = DEFAULT_HANDLER, onTaskStatus = DEFAULT_HANDLER } = options ?? {}; + + if (!this.taskId) { + // No task ID provided, just block for the result return await this.resultHandle; } // Whichever is successful first (or a failure if all fail) is returned. return Promise.allSettled([ - this.resultHandle, (async () => { // Blocks for a notifications/tasks/created with the provided task ID - await this.protocol.waitForTaskCreation(this.taskId!); - return await this.taskHandler(this.taskId!, options as TaskHandlerOptions); - })() - ]).then(([result, task]) => { - if (result.status === 'fulfilled') { - return result.value; - } else if (task.status === 'fulfilled') { + await this.taskCreatedHandle; + await onTaskCreated(); + return await this.taskHandler(this.taskId!, { + onTaskCreated, + onTaskStatus + }); + })(), + this.resultHandle + ]).then(([task, result]) => { + if (task.status === 'fulfilled') { return task.value; + } else if (result.status === 'fulfilled') { + return result.value; } const errors: unknown[] = [result.reason, task.reason]; @@ -55,7 +66,7 @@ export class PendingRequest setTimeout(resolve, task.pollFrequency ?? DEFAULT_POLLING_INTERNAL)); - } while (!(['complete', 'failed', 'cancelled', 'unknown'] as (typeof task.status)[]).includes(task.status)); + } while (!isTerminal(task.status)); // Process result return await this.protocol.getTaskResult({ taskId: taskId }, this.resultSchema); diff --git a/src/shared/task.ts b/src/shared/task.ts index 0a4b52560..617ab81aa 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -49,3 +49,14 @@ export interface TaskStore { */ updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise; } + +/** + * Checks if a task status represents a terminal state. + * Terminal states are those where the task has finished and will not change. + * + * @param status - The task status to check + * @returns True if the status is terminal (completed, failed, cancelled, or unknown) + */ +export function isTerminal(status: Task['status']): boolean { + return status === 'completed' || status === 'failed' || status === 'cancelled' || status === 'unknown'; +} From fcd2882303df6a332b1e4cbaef784db250875f30 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 23 Oct 2025 12:27:13 -0700 Subject: [PATCH 07/84] Create task example --- src/examples/client/simpleStreamableHttp.ts | 75 +++++++++++ src/examples/server/simpleStreamableHttp.ts | 32 ++++- src/examples/shared/inMemoryTaskStore.ts | 142 ++++++++++++++++++++ 3 files changed, 248 insertions(+), 1 deletion(-) create mode 100644 src/examples/shared/inMemoryTaskStore.ts diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 10f6afcbe..697353ef4 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -58,6 +58,7 @@ function printHelp(): void { console.log(' reconnect - Reconnect to the server'); console.log(' list-tools - List available tools'); console.log(' call-tool [args] - Call a tool with optional JSON arguments'); + console.log(' call-tool-task [args] - Call a tool with task-based execution (example: call-tool-task delay {"duration":3000})'); console.log(' greet [name] - Call the greet tool'); console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); console.log(' collect-info [type] - Test elicitation with collect-user-info tool (contact/preferences/feedback)'); @@ -141,6 +142,23 @@ function commandLoop(): void { break; } + case 'call-tool-task': + if (args.length < 2) { + console.log('Usage: call-tool-task [args]'); + } else { + const toolName = args[1]; + let toolArgs = {}; + if (args.length > 2) { + try { + toolArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await callToolTask(toolName, toolArgs); + } + break; + case 'list-prompts': await listPrompts(); break; @@ -777,6 +795,63 @@ async function readResource(uri: string): Promise { } } +async function callToolTask(name: string, args: Record): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + console.log(`Calling tool '${name}' with task-based execution...`); + console.log('Arguments:', args); + + // Use task-based execution - call now, fetch later + const taskId = `task-${Date.now()}`; + console.log(`Task ID: ${taskId}`); + console.log('This will return immediately while processing continues in the background...'); + + try { + // Begin the tool call with task metadata + const pendingRequest = client.beginCallTool( + { + name, + arguments: args + }, + CallToolResultSchema, + { + task: { + taskId, + keepAlive: 60000 // Keep results for 60 seconds + } + } + ); + + console.log('Waiting for task completion...'); + + await pendingRequest.result({ + onTaskCreated: () => { + console.log('Task created successfully'); + }, + onTaskStatus: task => { + console.log(` ${task.status}${task.error ? ` - ${task.error}` : ''}`); + } + }); + + console.log('Task completed! Fetching result...'); + + // Get the actual result + const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + + console.log('Tool result:'); + result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } + }); + } catch (error) { + console.log(`Error with task-based execution: ${error}`); + } +} + async function cleanup(): Promise { if (client && transport) { try { diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 5872cb4ac..966337f45 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -14,6 +14,7 @@ import { ResourceLink } from '../../types.js'; import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; +import { InMemoryTaskStore } from '../shared/inMemoryTaskStore.js'; import { setupAuthServer } from './demoInMemoryOAuthProvider.js'; import { OAuthMetadata } from 'src/shared/auth.js'; import { checkResourceAllowed } from 'src/shared/auth-utils.js'; @@ -24,6 +25,9 @@ import cors from 'cors'; const useOAuth = process.argv.includes('--oauth'); const strictOAuth = process.argv.includes('--oauth-strict'); +// Create shared task store for demonstration +const taskStore = new InMemoryTaskStore(); + // Create an MCP server with implementation details const getServer = () => { const server = new McpServer( @@ -33,7 +37,10 @@ const getServer = () => { icons: [{ src: './mcp.svg', sizes: ['512x512'], mimeType: 'image/svg+xml' }], websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk' }, - { capabilities: { logging: {} } } + { + capabilities: { logging: {} }, + taskStore // Enable task support + } ); // Register a simple tool that returns a greeting @@ -439,6 +446,29 @@ const getServer = () => { } ); + // Register a long-running tool that demonstrates task execution + server.registerTool( + 'delay', + { + title: 'Delay', + description: 'A simple tool that delays for a specified duration, useful for testing task execution', + inputSchema: { + duration: z.number().describe('Duration in milliseconds').default(5000) + } + }, + async ({ duration }): Promise => { + await new Promise(resolve => setTimeout(resolve, duration)); + return { + content: [ + { + type: 'text', + text: `Completed ${duration}ms delay` + } + ] + }; + } + ); + return server; }; diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts new file mode 100644 index 000000000..79d8a05bd --- /dev/null +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -0,0 +1,142 @@ +import { Task, TaskMetadata, Request, RequestId, Result } from '../../types.js'; +import { TaskStore, isTerminal } from '../../shared/task.js'; + +interface StoredTask { + task: Task; + request: Request; + requestId: RequestId; + result?: Result; +} + +/** + * A simple in-memory implementation of TaskStore for demonstration purposes. + * + * This implementation stores all tasks in memory and provides automatic cleanup + * based on the keepAlive duration specified in the task metadata. + * + * Note: This is not suitable for production use as all data is lost on restart. + * For production, consider implementing TaskStore with a database or distributed cache. + */ +export class InMemoryTaskStore implements TaskStore { + private tasks = new Map(); + private cleanupTimers = new Map>(); + + async createTask(metadata: TaskMetadata, requestId: RequestId, request: Request): Promise { + const taskId = metadata.taskId; + + if (this.tasks.has(taskId)) { + throw new Error(`Task with ID ${taskId} already exists`); + } + + const task: Task = { + taskId, + status: 'submitted', + keepAlive: metadata.keepAlive ?? null, + pollFrequency: 500 + }; + + this.tasks.set(taskId, { + task, + request, + requestId + }); + + // Schedule cleanup if keepAlive is specified + if (metadata.keepAlive) { + const timer = setTimeout(() => { + this.tasks.delete(taskId); + this.cleanupTimers.delete(taskId); + }, metadata.keepAlive); + + this.cleanupTimers.set(taskId, timer); + } + } + + async getTask(taskId: string): Promise { + const stored = this.tasks.get(taskId); + return stored ? { ...stored.task } : null; + } + + async storeTaskResult(taskId: string, result: Result): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + stored.result = result; + stored.task.status = 'completed'; + + // Reset cleanup timer to start from now (if keepAlive is set) + if (stored.task.keepAlive) { + const existingTimer = this.cleanupTimers.get(taskId); + if (existingTimer) { + clearTimeout(existingTimer); + } + + const timer = setTimeout(() => { + this.tasks.delete(taskId); + this.cleanupTimers.delete(taskId); + }, stored.task.keepAlive); + + this.cleanupTimers.set(taskId, timer); + } + } + + async getTaskResult(taskId: string): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + if (!stored.result) { + throw new Error(`Task ${taskId} has no result stored`); + } + + return stored.result; + } + + async updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + stored.task.status = status; + if (error) { + stored.task.error = error; + } + + // If task is in a terminal state and has keepAlive, start cleanup timer + if (isTerminal(status) && stored.task.keepAlive) { + const existingTimer = this.cleanupTimers.get(taskId); + if (existingTimer) { + clearTimeout(existingTimer); + } + + const timer = setTimeout(() => { + this.tasks.delete(taskId); + this.cleanupTimers.delete(taskId); + }, stored.task.keepAlive); + + this.cleanupTimers.set(taskId, timer); + } + } + + /** + * Cleanup all timers (useful for testing or graceful shutdown) + */ + cleanup(): void { + for (const timer of this.cleanupTimers.values()) { + clearTimeout(timer); + } + this.cleanupTimers.clear(); + this.tasks.clear(); + } + + /** + * Get all tasks (useful for debugging) + */ + getAllTasks(): Task[] { + return Array.from(this.tasks.values()).map(stored => ({ ...stored.task })); + } +} From c73b10567ca97a8ddc16f9c23dd0b303a33c8ec9 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 23 Oct 2025 12:57:40 -0700 Subject: [PATCH 08/84] Implement input_required status for tasks --- src/examples/client/simpleStreamableHttp.ts | 10 ++++++++-- src/examples/server/simpleStreamableHttp.ts | 19 +++++++++++++------ src/shared/protocol.ts | 11 ++++++++++- src/types.ts | 2 +- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 697353ef4..0b84cdfa1 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -17,7 +17,8 @@ import { ElicitRequestSchema, ResourceLink, ReadResourceRequest, - ReadResourceResultSchema + ReadResourceResultSchema, + RELATED_TASK_META_KEY } from '../../types.js'; import { getDisplayName } from '../../shared/metadataUtils.js'; import Ajv from 'ajv'; @@ -249,6 +250,7 @@ async function connect(url?: string): Promise { client.setRequestHandler(ElicitRequestSchema, async request => { console.log('\n🔔 Elicitation Request Received:'); console.log(`Message: ${request.params.message}`); + console.log(`Related Task: ${request.params._meta?.[RELATED_TASK_META_KEY]?.taskId}`); console.log('Requested Schema:'); console.log(JSON.stringify(request.params.requestedSchema, null, 2)); @@ -827,12 +829,16 @@ async function callToolTask(name: string, args: Record): Promis console.log('Waiting for task completion...'); + let lastStatus = ''; await pendingRequest.result({ onTaskCreated: () => { console.log('Task created successfully'); }, onTaskStatus: task => { - console.log(` ${task.status}${task.error ? ` - ${task.error}` : ''}`); + if (lastStatus !== task.status) { + console.log(` ${task.status}${task.error ? ` - ${task.error}` : ''}`); + } + lastStatus = task.status; } }); diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 966337f45..ec73a4f02 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -7,6 +7,7 @@ import { getOAuthProtectedResourceMetadataUrl, mcpAuthMetadataRouter } from '../ import { requireBearerAuth } from '../../server/auth/middleware/bearerAuth.js'; import { CallToolResult, + ElicitResultSchema, GetPromptResult, isInitializeRequest, PrimitiveSchemaDefinition, @@ -126,7 +127,7 @@ const getServer = () => { { infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect') }, - async ({ infoType }): Promise => { + async ({ infoType }, extra): Promise => { let message: string; let requestedSchema: { type: 'object'; @@ -221,11 +222,17 @@ const getServer = () => { } try { - // Use the underlying server instance to elicit input from the client - const result = await server.server.elicitInput({ - message, - requestedSchema - }); + // Elicit input from the client + const result = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + message, + requestedSchema + } + }, + ElicitResultSchema + ); if (result.action === 'accept') { return { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 6bbdbf84d..382440fe6 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -492,7 +492,16 @@ export abstract class Protocol { const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; - return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id, relatedTask }); + if (taskMetadata && this._taskStore) { + await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'input_required'); + } + try { + return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id, relatedTask }); + } finally { + if (taskMetadata && this._taskStore) { + await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'working'); + } + } }, authInfo: extra?.authInfo, requestId: request.id, diff --git a/src/types.ts b/src/types.ts index 23caa5f41..24ea5881d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -501,7 +501,7 @@ export const ProgressNotificationSchema = NotificationSchema.extend({ */ export const TaskSchema = z.object({ taskId: z.string(), - status: z.enum(['submitted', 'working', 'completed', 'failed', 'cancelled', 'unknown']), + status: z.enum(['submitted', 'working', 'input_required', 'completed', 'failed', 'cancelled', 'unknown']), keepAlive: z.union([z.number(), z.null()]), pollFrequency: z.optional(z.number()), error: z.optional(z.string()) From b028061b83b060ee0545f92f4e6a7206a356aa25 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 23 Oct 2025 14:38:04 -0700 Subject: [PATCH 09/84] Implement unit tests for task support --- src/examples/shared/inMemoryTaskStore.test.ts | 374 ++++++++++++++ src/server/index.test.ts | 259 ++++++++++ src/shared/protocol.test.ts | 469 ++++++++++++++++++ src/shared/protocol.ts | 40 +- 4 files changed, 1120 insertions(+), 22 deletions(-) create mode 100644 src/examples/shared/inMemoryTaskStore.test.ts diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts new file mode 100644 index 000000000..2e4020a7f --- /dev/null +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -0,0 +1,374 @@ +import { InMemoryTaskStore } from './inMemoryTaskStore.js'; +import { TaskMetadata, Request } from '../../types.js'; + +describe('InMemoryTaskStore', () => { + let store: InMemoryTaskStore; + + beforeEach(() => { + store = new InMemoryTaskStore(); + }); + + afterEach(() => { + store.cleanup(); + }); + + describe('createTask', () => { + it('should create a new task with submitted status', async () => { + const metadata: TaskMetadata = { + taskId: 'task-1', + keepAlive: 60000 + }; + const request: Request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + await store.createTask(metadata, 123, request); + + const task = await store.getTask('task-1'); + expect(task).toBeDefined(); + expect(task?.taskId).toBe('task-1'); + expect(task?.status).toBe('submitted'); + expect(task?.keepAlive).toBe(60000); + expect(task?.pollFrequency).toBe(500); + }); + + it('should create task without keepAlive', async () => { + const metadata: TaskMetadata = { + taskId: 'task-no-keepalive' + }; + const request: Request = { + method: 'tools/call', + params: {} + }; + + await store.createTask(metadata, 456, request); + + const task = await store.getTask('task-no-keepalive'); + expect(task).toBeDefined(); + expect(task?.keepAlive).toBeNull(); + }); + + it('should reject duplicate taskId', async () => { + const metadata: TaskMetadata = { + taskId: 'duplicate-task' + }; + const request: Request = { + method: 'tools/call', + params: {} + }; + + await store.createTask(metadata, 789, request); + + await expect(store.createTask(metadata, 790, request)).rejects.toThrow('Task with ID duplicate-task already exists'); + }); + }); + + describe('getTask', () => { + it('should return null for non-existent task', async () => { + const task = await store.getTask('non-existent'); + expect(task).toBeNull(); + }); + + it('should return task state', async () => { + const metadata: TaskMetadata = { + taskId: 'get-test' + }; + const request: Request = { + method: 'tools/call', + params: {} + }; + + await store.createTask(metadata, 111, request); + await store.updateTaskStatus('get-test', 'working'); + + const task = await store.getTask('get-test'); + expect(task).toBeDefined(); + expect(task?.status).toBe('working'); + }); + }); + + describe('updateTaskStatus', () => { + beforeEach(async () => { + const metadata: TaskMetadata = { + taskId: 'status-test' + }; + await store.createTask(metadata, 222, { + method: 'tools/call', + params: {} + }); + }); + + it('should update task status from submitted to working', async () => { + await store.updateTaskStatus('status-test', 'working'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('working'); + }); + + it('should update task status to input_required', async () => { + await store.updateTaskStatus('status-test', 'input_required'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('input_required'); + }); + + it('should update task status to completed', async () => { + await store.updateTaskStatus('status-test', 'completed'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('completed'); + }); + + it('should update task status to failed with error', async () => { + await store.updateTaskStatus('status-test', 'failed', 'Something went wrong'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('failed'); + expect(task?.error).toBe('Something went wrong'); + }); + + it('should update task status to cancelled', async () => { + await store.updateTaskStatus('status-test', 'cancelled'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('cancelled'); + }); + + it('should throw if task not found', async () => { + await expect(store.updateTaskStatus('non-existent', 'working')).rejects.toThrow('Task with ID non-existent not found'); + }); + }); + + describe('storeTaskResult', () => { + beforeEach(async () => { + const metadata: TaskMetadata = { + taskId: 'result-test', + keepAlive: 60000 + }; + await store.createTask(metadata, 333, { + method: 'tools/call', + params: {} + }); + }); + + it('should store task result and set status to completed', async () => { + const result = { + content: [{ type: 'text' as const, text: 'Success!' }] + }; + + await store.storeTaskResult('result-test', result); + + const task = await store.getTask('result-test'); + expect(task?.status).toBe('completed'); + + const storedResult = await store.getTaskResult('result-test'); + expect(storedResult).toEqual(result); + }); + + it('should throw if task not found', async () => { + await expect(store.storeTaskResult('non-existent', {})).rejects.toThrow('Task with ID non-existent not found'); + }); + }); + + describe('getTaskResult', () => { + it('should throw if task not found', async () => { + await expect(store.getTaskResult('non-existent')).rejects.toThrow('Task with ID non-existent not found'); + }); + + it('should throw if task has no result stored', async () => { + const metadata: TaskMetadata = { + taskId: 'no-result' + }; + await store.createTask(metadata, 444, { + method: 'tools/call', + params: {} + }); + + await expect(store.getTaskResult('no-result')).rejects.toThrow('Task no-result has no result stored'); + }); + + it('should return stored result', async () => { + const metadata: TaskMetadata = { + taskId: 'with-result' + }; + await store.createTask(metadata, 555, { + method: 'tools/call', + params: {} + }); + + const result = { + content: [{ type: 'text' as const, text: 'Result data' }] + }; + await store.storeTaskResult('with-result', result); + + const retrieved = await store.getTaskResult('with-result'); + expect(retrieved).toEqual(result); + }); + }); + + describe('keepAlive cleanup', () => { + beforeEach(() => { + jest.useFakeTimers(); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('should cleanup task after keepAlive duration', async () => { + const metadata: TaskMetadata = { + taskId: 'cleanup-test', + keepAlive: 1000 + }; + await store.createTask(metadata, 666, { + method: 'tools/call', + params: {} + }); + + // Task should exist initially + let task = await store.getTask('cleanup-test'); + expect(task).toBeDefined(); + + // Fast-forward past keepAlive + jest.advanceTimersByTime(1001); + + // Task should be cleaned up + task = await store.getTask('cleanup-test'); + expect(task).toBeNull(); + }); + + it('should reset cleanup timer when result is stored', async () => { + const metadata: TaskMetadata = { + taskId: 'reset-cleanup', + keepAlive: 1000 + }; + await store.createTask(metadata, 777, { + method: 'tools/call', + params: {} + }); + + // Fast-forward 500ms + jest.advanceTimersByTime(500); + + // Store result (should reset timer) + await store.storeTaskResult('reset-cleanup', { + content: [{ type: 'text' as const, text: 'Done' }] + }); + + // Fast-forward another 500ms (total 1000ms since creation, but timer was reset) + jest.advanceTimersByTime(500); + + // Task should still exist + const task = await store.getTask('reset-cleanup'); + expect(task).toBeDefined(); + + // Fast-forward remaining time + jest.advanceTimersByTime(501); + + // Now task should be cleaned up + const cleanedTask = await store.getTask('reset-cleanup'); + expect(cleanedTask).toBeNull(); + }); + + it('should not cleanup tasks without keepAlive', async () => { + const metadata: TaskMetadata = { + taskId: 'no-cleanup' + }; + await store.createTask(metadata, 888, { + method: 'tools/call', + params: {} + }); + + // Fast-forward a long time + jest.advanceTimersByTime(100000); + + // Task should still exist + const task = await store.getTask('no-cleanup'); + expect(task).toBeDefined(); + }); + + it('should start cleanup timer when task reaches terminal state', async () => { + const metadata: TaskMetadata = { + taskId: 'terminal-cleanup', + keepAlive: 1000 + }; + await store.createTask(metadata, 999, { + method: 'tools/call', + params: {} + }); + + // Task in non-terminal state, fast-forward + jest.advanceTimersByTime(1001); + + // Task should be cleaned up + let task = await store.getTask('terminal-cleanup'); + expect(task).toBeNull(); + + // Create another task + const metadata2: TaskMetadata = { + taskId: 'terminal-cleanup-2', + keepAlive: 2000 + }; + await store.createTask(metadata2, 1000, { + method: 'tools/call', + params: {} + }); + + // Update to terminal state + await store.updateTaskStatus('terminal-cleanup-2', 'completed'); + + // Fast-forward past original keepAlive + jest.advanceTimersByTime(2001); + + // Task should be cleaned up + task = await store.getTask('terminal-cleanup-2'); + expect(task).toBeNull(); + }); + }); + + describe('getAllTasks', () => { + it('should return all tasks', async () => { + await store.createTask({ taskId: 'task-1' }, 1, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-2' }, 2, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-3' }, 3, { + method: 'tools/call', + params: {} + }); + + const tasks = store.getAllTasks(); + expect(tasks).toHaveLength(3); + expect(tasks.map(t => t.taskId).sort()).toEqual(['task-1', 'task-2', 'task-3']); + }); + + it('should return empty array when no tasks', () => { + const tasks = store.getAllTasks(); + expect(tasks).toEqual([]); + }); + }); + + describe('cleanup', () => { + it('should clear all timers and tasks', async () => { + await store.createTask({ taskId: 'task-1', keepAlive: 1000 }, 1, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-2', keepAlive: 2000 }, 2, { + method: 'tools/call', + params: {} + }); + + expect(store.getAllTasks()).toHaveLength(2); + + store.cleanup(); + + expect(store.getAllTasks()).toHaveLength(0); + }); + }); +}); diff --git a/src/server/index.test.ts b/src/server/index.test.ts index d056707fe..6d74707cc 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -21,6 +21,8 @@ import { import { Transport } from '../shared/transport.js'; import { InMemoryTransport } from '../inMemory.js'; import { Client } from '../client/index.js'; +import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; +import { CallToolRequestSchema, CallToolResultSchema } from '../types.js'; test('should accept latest protocol version', async () => { let sendPromiseResolve: (value: unknown) => void; @@ -955,3 +957,260 @@ test('should respect log level for transport with sessionId', async () => { await server.sendLoggingMessage(warningParams, SESSION_ID); expect(clientTransport.onmessage).toHaveBeenCalled(); }); + +describe('Task-based execution', () => { + test('server with TaskStore should handle task-based tool execution', async () => { + const taskStore = new InMemoryTaskStore(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore + } + ); + + // Set up a tool handler that simulates some work + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + // Simulate some async work + await new Promise(resolve => setTimeout(resolve, 10)); + return { + content: [{ type: 'text', text: 'Tool executed successfully!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Use beginCallTool to create a task + const taskId = 'test-task-123'; + const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { + taskId, + keepAlive: 60000 + } + }); + + // Wait for the task to complete + await pendingRequest.result(); + + // Verify we can retrieve the task + const task = await client.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.status).toBe('completed'); + + // Verify we can retrieve the result + const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + expect(result.content).toEqual([{ type: 'text', text: 'Tool executed successfully!' }]); + + // Cleanup + taskStore.cleanup(); + }); + + test('server without TaskStore should reject task-based requests', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + // No taskStore configured + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Success!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to get a task when server doesn't have TaskStore + // The server will return a "Method not found" error + await expect(client.getTask({ taskId: 'non-existent' })).rejects.toThrow('Method not found'); + }); + + test('should automatically attach related-task metadata to nested requests during tool execution', async () => { + const taskStore = new InMemoryTaskStore(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + // Track the elicitation request to verify related-task metadata + let capturedElicitRequest: z.infer | null = null; + + // Set up client elicitation handler + client.setRequestHandler(ElicitRequestSchema, async request => { + // Capture the request to verify metadata later + capturedElicitRequest = request; + + return { + action: 'accept', + content: { + username: 'test-user' + } + }; + }); + + // Set up server tool that makes a nested elicitation request + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (request.params.name === 'collect-info') { + // During tool execution, make a nested request to the client using extra.sendRequest + // This should AUTOMATICALLY attach the related-task metadata + const elicitResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' } + }, + required: ['username'] + } + } + }, + z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }) + ); + + return { + content: [ + { + type: 'text', + text: `Collected username: ${elicitResult.action === 'accept' && elicitResult.content ? (elicitResult.content as Record).username : 'none'}` + } + ] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'collect-info', + description: 'Collects user info via elicitation', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Call tool WITH task metadata + const taskId = 'test-task-456'; + const pendingRequest = client.beginCallTool({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { + task: { + taskId, + keepAlive: 60000 + } + }); + + // Wait for completion + await pendingRequest.result(); + + // Verify the nested elicitation request received the related-task metadata + expect(capturedElicitRequest).toBeDefined(); + expect(capturedElicitRequest!.params._meta).toBeDefined(); + expect(capturedElicitRequest!.params._meta?.['modelcontextprotocol.io/related-task']).toEqual({ + taskId: 'test-task-456' + }); + + // Verify tool result was correct + const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + expect(result.content).toEqual([ + { + type: 'text', + text: 'Collected username: test-user' + } + ]); + + // Cleanup + taskStore.cleanup(); + }); +}); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 1c098eafa..4eccfbd91 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -742,3 +742,472 @@ describe('mergeCapabilities', () => { expect(merged).toEqual({}); }); }); + +describe('Task-based execution', () => { + let protocol: Protocol; + let transport: MockTransport; + let sendSpy: jest.SpyInstance; + + beforeEach(() => { + transport = new MockTransport(); + sendSpy = jest.spyOn(transport, 'send'); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })(); + }); + + describe('beginRequest with task metadata', () => { + it('should inject task metadata into _meta field', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + protocol.beginRequest(request, resultSchema, { + task: { + taskId: 'my-task-123', + keepAlive: 30000 + } + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tools/call', + params: { + name: 'test-tool', + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'my-task-123', + keepAlive: 30000 + } + } + } + }), + expect.any(Object) + ); + }); + + it('should preserve existing _meta when adding task metadata', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { + name: 'test-tool', + _meta: { + customField: 'customValue' + } + } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + protocol.beginRequest(request, resultSchema, { + task: { + taskId: 'my-task-456' + } + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + params: { + name: 'test-tool', + _meta: { + customField: 'customValue', + 'modelcontextprotocol.io/task': { + taskId: 'my-task-456' + } + } + } + }), + expect.any(Object) + ); + }); + + it('should return PendingRequest object', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + const pendingRequest = protocol.beginRequest(request, resultSchema, { + task: { + taskId: 'my-task-789' + } + }); + + expect(pendingRequest).toBeDefined(); + expect(pendingRequest.taskId).toBe('my-task-789'); + }); + }); + + describe('relatedTask metadata', () => { + it('should inject relatedTask metadata into _meta field', async () => { + await protocol.connect(transport); + + const request = { + method: 'notifications/message', + params: { data: 'test' } + }; + + const resultSchema = z.object({}); + + protocol.beginRequest(request, resultSchema, { + relatedTask: { + taskId: 'parent-task-123' + } + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + params: { + data: 'test', + _meta: { + 'modelcontextprotocol.io/related-task': { + taskId: 'parent-task-123' + } + } + } + }), + expect.any(Object) + ); + }); + + it('should work with notification method', async () => { + await protocol.connect(transport); + + await protocol.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { + taskId: 'parent-task-456' + } + } + ); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'notifications/message', + params: { + level: 'info', + data: 'test message', + _meta: { + 'modelcontextprotocol.io/related-task': { + taskId: 'parent-task-456' + } + } + } + }), + expect.any(Object) + ); + }); + }); + + describe('task metadata combination', () => { + it('should combine task, relatedTask, and progress metadata', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + protocol.beginRequest(request, resultSchema, { + task: { + taskId: 'my-task-combined' + }, + relatedTask: { + taskId: 'parent-task' + }, + onprogress: jest.fn() + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + params: { + name: 'test-tool', + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'my-task-combined' + }, + 'modelcontextprotocol.io/related-task': { + taskId: 'parent-task' + }, + progressToken: expect.any(Number) + } + } + }), + expect.any(Object) + ); + }); + }); + + describe('task status transitions', () => { + it('should transition from submitted to working when handler starts', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => ({ + result: 'success' + })); + + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); + + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(mockTaskStore.createTask).toHaveBeenCalledWith({ taskId: 'test-task', keepAlive: 60000 }, 1, { + method: 'test/method', + params: expect.any(Object) + }); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working'); + }); + + it('should transition to input_required during extra.sendRequest', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + const responsiveTransport = new MockTransport(); + responsiveTransport.send = jest.fn().mockImplementation(async (message: unknown) => { + if ( + typeof message === 'object' && + message !== null && + 'method' in message && + 'id' in message && + message.method === 'nested/request' && + responsiveTransport.onmessage + ) { + setTimeout(() => { + responsiveTransport.onmessage?.({ + jsonrpc: '2.0', + id: (message as { id: number }).id, + result: { nested: 'response' } + }); + }, 5); + } + }); + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(responsiveTransport); + + const capturedUpdateCalls: Array<{ taskId: string; status: string }> = []; + mockTaskStore.updateTaskStatus.mockImplementation((taskId, status) => { + capturedUpdateCalls.push({ taskId, status }); + return Promise.resolve(); + }); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async (_request, extra) => { + await extra.sendRequest({ method: 'nested/request', params: {} }, z.object({ nested: z.string() })); + return { result: 'success' }; + }); + + responsiveTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); + + await new Promise(resolve => setTimeout(resolve, 100)); + + expect(capturedUpdateCalls).toContainEqual({ taskId: 'test-task', status: 'working' }); + expect(capturedUpdateCalls).toContainEqual({ taskId: 'test-task', status: 'input_required' }); + + const inputRequiredIndex = capturedUpdateCalls.findIndex(c => c.status === 'input_required'); + const workingCalls = capturedUpdateCalls.filter(c => c.status === 'working'); + expect(workingCalls).toHaveLength(2); + + let workingCount = 0; + const secondWorkingIndex = capturedUpdateCalls.findIndex(c => { + if (c.status === 'working') { + workingCount++; + return workingCount === 2; + } + return false; + }); + expect(secondWorkingIndex).toBeGreaterThan(inputRequiredIndex); + }); + + it('should mark task as completed when storeTaskResult is called', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => ({ + result: 'success' + })); + + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); + + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith('test-task', { result: 'success' }); + }); + + it('should mark task as cancelled when notifications/cancelled is received', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue({ taskId: 'test-task', status: 'working', keepAlive: null, pollFrequency: 500 }), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + void protocol.request({ method: 'test/slow', params: {} }, z.object({ result: z.string() }), { + task: { taskId: 'test-task', keepAlive: 60000 } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + transport.onmessage?.({ + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { + requestId: 0, + reason: 'User cancelled' + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'cancelled'); + }); + + it('should mark task as failed when updateTaskStatus to working fails', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockRejectedValueOnce(new Error('Failed to update status')).mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => ({ + result: 'success' + })); + + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); + + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working'); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'failed', 'Failed to mark task as working'); + }); + }); +}); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 382440fe6..6d811953a 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -518,29 +518,25 @@ export abstract class Protocol { @@ -548,7 +544,7 @@ export abstract class Protocol Date: Thu, 23 Oct 2025 16:24:41 -0700 Subject: [PATCH 10/84] Add docs for task augmentation --- README.md | 164 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) diff --git a/README.md b/README.md index 92f56786f..47588d600 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ - [Improving Network Efficiency with Notification Debouncing](#improving-network-efficiency-with-notification-debouncing) - [Low-Level Server](#low-level-server) - [Eliciting User Input](#eliciting-user-input) + - [Task-Based Execution](#task-based-execution) - [Writing MCP Clients](#writing-mcp-clients) - [Proxy Authorization Requests Upstream](#proxy-authorization-requests-upstream) - [Backwards Compatibility](#backwards-compatibility) @@ -1301,6 +1302,169 @@ client.setRequestHandler(ElicitRequestSchema, async request => { **Note**: Elicitation requires client support. Clients must declare the `elicitation` capability during initialization. +### Task-Based Execution + +Task-based execution enables "call-now, fetch-later" patterns for long-running operations. This is useful for tools that take significant time to complete, where clients may want to disconnect and check on progress or retrieve results later. + +Common use cases include: + +- Long-running data processing or analysis +- Code migration or refactoring operations +- Complex computational tasks +- Operations that require periodic status updates + +#### Server-Side: Implementing Task Support + +To enable task-based execution, configure your server with a `TaskStore` implementation. The SDK doesn't provide a built-in TaskStore—you'll need to implement one backed by your database of choice: + +```typescript +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; +import { TaskStore } from '@modelcontextprotocol/sdk/shared/task.js'; +import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; + +// Implement TaskStore backed by your database (e.g., PostgreSQL, Redis, etc.) +class MyTaskStore implements TaskStore { + async createTask(metadata, requestId, request) { + // Store task in your database + } + + async getTask(taskId) { + // Retrieve task from your database + } + + async updateTaskStatus(taskId, status, errorMessage?) { + // Update task status in your database + } + + async storeTaskResult(taskId, result) { + // Store task result in your database + } + + async getTaskResult(taskId) { + // Retrieve task result from your database + } +} + +const taskStore = new MyTaskStore(); + +const server = new Server( + { + name: 'task-enabled-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore // Enable task support + } +); + +// Set up a long-running tool handler as usual +server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'analyze-data') { + // Simulate long-running analysis + await new Promise(resolve => setTimeout(resolve, 30000)); + + return { + content: [ + { + type: 'text', + text: 'Analysis complete!' + } + ] + }; + } + throw new Error('Unknown tool'); +}); + +server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'analyze-data', + description: 'Perform data analysis (long-running)', + inputSchema: { + type: 'object', + properties: { + dataset: { type: 'string' } + } + } + } + ] +})); +``` + +**Note**: See `src/examples/shared/inMemoryTaskStore.ts` in the SDK source for a reference implementation suitable for development and testing. + +#### Client-Side: Using Task-Based Execution + +Clients use `beginCallTool()` to initiate task-based operations. The returned `PendingRequest` object provides automatic polling and status tracking: + +```typescript +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; + +const client = new Client({ + name: 'task-client', + version: '1.0.0' +}); + +// ... connect to server ... + +// Initiate a task-based tool call +const taskId = 'analysis-task-123'; +const pendingRequest = client.beginCallTool( + { + name: 'analyze-data', + arguments: { dataset: 'user-data.csv' } + }, + CallToolResultSchema, + { + task: { + taskId, + keepAlive: 300000 // Keep results for 5 minutes after completion + } + } +); + +// Option 1: Wait for completion with status callbacks +const result = await pendingRequest.result({ + onTaskCreated: () => { + console.log('Task created successfully'); + }, + onTaskStatus: task => { + console.log(`Task status: ${task.status}`); + // Status can be: 'submitted', 'working', 'input_required', 'completed', 'failed', or 'cancelled' + } +}); +console.log('Task completed:', result); + +// Option 2: Fire and forget - disconnect and reconnect later +// (useful when you don't want to wait for long-running tasks) +// Later, after disconnecting and reconnecting to the server: +const taskStatus = await client.getTask({ taskId }); +console.log('Task status:', taskStatus.status); + +if (taskStatus.status === 'completed') { + const taskResult = await client.getTaskResult({ taskId }, CallToolResultSchema); + console.log('Retrieved result after reconnect:', taskResult); +} +``` + +#### Task Status Lifecycle + +Tasks transition through the following states: + +- **submitted**: Task has been created and queued +- **working**: Task is actively being processed +- **input_required**: Task is waiting for additional input (e.g., from elicitation) +- **completed**: Task finished successfully +- **failed**: Task encountered an error +- **cancelled**: Task was cancelled by the client +- **unknown**: Task status could not be determined (terminal state, rarely occurs) + +The `keepAlive` parameter determines how long the server retains task results after completion. This allows clients to retrieve results even after disconnecting and reconnecting. + ### Writing MCP Clients The SDK provides a high-level client interface: From 5dc999f60593132087f2026ba3497aa4941e96eb Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 27 Oct 2025 12:19:35 -0700 Subject: [PATCH 11/84] Implement tasks/list method --- src/examples/shared/inMemoryTaskStore.test.ts | 73 ++++++ src/examples/shared/inMemoryTaskStore.ts | 26 ++ src/shared/protocol.test.ts | 248 +++++++++++++++++- src/shared/protocol.ts | 27 ++ src/shared/task.ts | 8 + src/types.ts | 28 +- 6 files changed, 401 insertions(+), 9 deletions(-) diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index 2e4020a7f..9c8c7dab0 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -353,6 +353,79 @@ describe('InMemoryTaskStore', () => { }); }); + describe('listTasks', () => { + it('should return empty list when no tasks', async () => { + const result = await store.listTasks(); + expect(result.tasks).toEqual([]); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should return all tasks when less than page size', async () => { + await store.createTask({ taskId: 'task-1' }, 1, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-2' }, 2, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-3' }, 3, { + method: 'tools/call', + params: {} + }); + + const result = await store.listTasks(); + expect(result.tasks).toHaveLength(3); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should paginate when more than page size', async () => { + // Create 15 tasks (page size is 10) + for (let i = 1; i <= 15; i++) { + await store.createTask({ taskId: `task-${i}` }, i, { + method: 'tools/call', + params: {} + }); + } + + // Get first page + const page1 = await store.listTasks(); + expect(page1.tasks).toHaveLength(10); + expect(page1.nextCursor).toBeDefined(); + + // Get second page using cursor + const page2 = await store.listTasks(page1.nextCursor); + expect(page2.tasks).toHaveLength(5); + expect(page2.nextCursor).toBeUndefined(); + }); + + it('should throw error for invalid cursor', async () => { + await store.createTask({ taskId: 'task-1' }, 1, { + method: 'tools/call', + params: {} + }); + + await expect(store.listTasks('non-existent-cursor')).rejects.toThrow('Invalid cursor: non-existent-cursor'); + }); + + it('should continue from cursor correctly', async () => { + // Create tasks with predictable IDs + for (let i = 1; i <= 5; i++) { + await store.createTask({ taskId: `task-${i}` }, i, { + method: 'tools/call', + params: {} + }); + } + + // Get first 3 tasks + const allTaskIds = Array.from(store.getAllTasks().map(t => t.taskId)); + const result = await store.listTasks(allTaskIds[2]); + + // Should get tasks after task-3 + expect(result.tasks).toHaveLength(2); + }); + }); + describe('cleanup', () => { it('should clear all timers and tasks', async () => { await store.createTask({ taskId: 'task-1', keepAlive: 1000 }, 1, { diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index 79d8a05bd..c9f297c86 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -122,6 +122,32 @@ export class InMemoryTaskStore implements TaskStore { } } + async listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }> { + const PAGE_SIZE = 10; + const allTaskIds = Array.from(this.tasks.keys()); + + let startIndex = 0; + if (cursor) { + const cursorIndex = allTaskIds.indexOf(cursor); + if (cursorIndex >= 0) { + startIndex = cursorIndex + 1; + } else { + // Invalid cursor - throw error + throw new Error(`Invalid cursor: ${cursor}`); + } + } + + const pageTaskIds = allTaskIds.slice(startIndex, startIndex + PAGE_SIZE); + const tasks = pageTaskIds.map(taskId => { + const stored = this.tasks.get(taskId)!; + return { ...stored.task }; + }); + + const nextCursor = startIndex + PAGE_SIZE < allTaskIds.length ? pageTaskIds[pageTaskIds.length - 1] : undefined; + + return { tasks, nextCursor }; + } + /** * Cleanup all timers (useful for testing or graceful shutdown) */ diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 4eccfbd91..a84e5a0ec 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -972,7 +972,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue(null), updateTaskStatus: jest.fn().mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; protocol = new (class extends Protocol { @@ -1016,7 +1017,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue(null), updateTaskStatus: jest.fn().mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; const responsiveTransport = new MockTransport(); @@ -1098,7 +1100,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue(null), updateTaskStatus: jest.fn().mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; protocol = new (class extends Protocol { @@ -1138,7 +1141,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue({ taskId: 'test-task', status: 'working', keepAlive: null, pollFrequency: 500 }), updateTaskStatus: jest.fn().mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; protocol = new (class extends Protocol { @@ -1175,7 +1179,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue(null), updateTaskStatus: jest.fn().mockRejectedValueOnce(new Error('Failed to update status')).mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; protocol = new (class extends Protocol { @@ -1210,4 +1215,237 @@ describe('Task-based execution', () => { expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'failed', 'Failed to mark task as working'); }); }); + + describe('listTasks', () => { + it('should handle tasks/list requests and return tasks from TaskStore', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ + tasks: [ + { taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }, + { taskId: 'task-2', status: 'working', keepAlive: 60000, pollFrequency: 1000 } + ], + nextCursor: 'task-2' + }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'tasks/list', + params: {} + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(1); + expect(sentMessage.result.tasks).toEqual([ + { taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }, + { taskId: 'task-2', status: 'working', keepAlive: 60000, pollFrequency: 1000 } + ]); + expect(sentMessage.result.nextCursor).toBe('task-2'); + expect(sentMessage.result._meta).toEqual({}); + }); + + it('should handle tasks/list requests with cursor for pagination', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ + tasks: [{ taskId: 'task-3', status: 'submitted', keepAlive: null, pollFrequency: 500 }], + nextCursor: undefined + }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request with cursor + transport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/list', + params: { + cursor: 'task-2' + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith('task-2'); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(2); + expect(sentMessage.result.tasks).toEqual([{ taskId: 'task-3', status: 'submitted', keepAlive: null, pollFrequency: 500 }]); + expect(sentMessage.result.nextCursor).toBeUndefined(); + expect(sentMessage.result._meta).toEqual({}); + }); + + it('should handle tasks/list requests with empty results', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ + tasks: [], + nextCursor: undefined + }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request + transport.onmessage?.({ + jsonrpc: '2.0', + id: 3, + method: 'tasks/list', + params: {} + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(3); + expect(sentMessage.result.tasks).toEqual([]); + expect(sentMessage.result.nextCursor).toBeUndefined(); + expect(sentMessage.result._meta).toEqual({}); + }); + + it('should return error for invalid cursor', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockRejectedValue(new Error('Invalid cursor: bad-cursor')) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request with invalid cursor + transport.onmessage?.({ + jsonrpc: '2.0', + id: 4, + method: 'tasks/list', + params: { + cursor: 'bad-cursor' + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith('bad-cursor'); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(4); + expect(sentMessage.error).toBeDefined(); + expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.message).toContain('Failed to list tasks'); + expect(sentMessage.error.message).toContain('Invalid cursor'); + }); + + it('should call listTasks method from client side', async () => { + await protocol.connect(transport); + + const listTasksPromise = protocol.listTasks(); + + // Simulate server response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sendSpy.mock.calls[0][0].id, + result: { + tasks: [{ taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }], + nextCursor: undefined, + _meta: {} + } + }); + }, 10); + + const result = await listTasksPromise; + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tasks/list', + params: undefined + }), + expect.any(Object) + ); + expect(result.tasks).toHaveLength(1); + expect(result.tasks[0].taskId).toBe('task-1'); + }); + + it('should call listTasks with cursor from client side', async () => { + await protocol.connect(transport); + + const listTasksPromise = protocol.listTasks({ cursor: 'task-10' }); + + // Simulate server response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sendSpy.mock.calls[0][0].id, + result: { + tasks: [{ taskId: 'task-11', status: 'working', keepAlive: 30000, pollFrequency: 1000 }], + nextCursor: 'task-11', + _meta: {} + } + }); + }, 10); + + const result = await listTasksPromise; + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tasks/list', + params: { cursor: 'task-10' } + }), + expect.any(Object) + ); + expect(result.tasks).toHaveLength(1); + expect(result.tasks[0].taskId).toBe('task-11'); + expect(result.nextCursor).toBe('task-11'); + }); + }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 6d811953a..a25107613 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -8,6 +8,8 @@ import { GetTaskResultSchema, GetTaskPayloadRequest, GetTaskPayloadRequestSchema, + ListTasksRequestSchema, + ListTasksResultSchema, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, @@ -319,6 +321,23 @@ export abstract class Protocol { + try { + const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor); + // @ts-expect-error SendResultT cannot contain ListTasksResult, but we include it in our derived types everywhere else + return { + tasks, + nextCursor, + _meta: {} + } as SendResultT; + } catch (error) { + throw new McpError( + ErrorCode.InvalidParams, + `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` + ); + } + }); } } @@ -856,6 +875,14 @@ export abstract class Protocol> { + // @ts-expect-error SendRequestT cannot directly contain ListTasksRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); + } + /** * Emits a notification, which is a one-way message that does not expect a response. */ diff --git a/src/shared/task.ts b/src/shared/task.ts index 617ab81aa..fbcd22e82 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -48,6 +48,14 @@ export interface TaskStore { * @param error - Optional error message if status is 'failed' or 'cancelled' */ updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise; + + /** + * Lists tasks, optionally starting from a pagination cursor. + * + * @param cursor - Optional cursor for pagination + * @returns An object containing the tasks array and an optional nextCursor + */ + listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; } /** diff --git a/src/types.ts b/src/types.ts index 24ea5881d..5b528f551 100644 --- a/src/types.ts +++ b/src/types.ts @@ -559,6 +559,20 @@ export const PaginatedResultSchema = ResultSchema.extend({ nextCursor: z.optional(CursorSchema) }); +/** + * A request to list tasks. + */ +export const ListTasksRequestSchema = PaginatedRequestSchema.extend({ + method: z.literal('tasks/list') +}); + +/** + * The response to a tasks/list request. + */ +export const ListTasksResultSchema = PaginatedResultSchema.extend({ + tasks: z.array(TaskSchema) +}); + /* Resources */ /** * The contents of a specific resource or sub-resource. @@ -1518,7 +1532,8 @@ export const ClientRequestSchema = z.union([ CallToolRequestSchema, ListToolsRequestSchema, GetTaskRequestSchema, - GetTaskPayloadRequestSchema + GetTaskPayloadRequestSchema, + ListTasksRequestSchema ]); export const ClientNotificationSchema = z.union([ @@ -1534,7 +1549,8 @@ export const ClientResultSchema = z.union([ CreateMessageResultSchema, ElicitResultSchema, ListRootsResultSchema, - GetTaskResultSchema + GetTaskResultSchema, + ListTasksResultSchema ]); /* Server messages */ @@ -1544,7 +1560,8 @@ export const ServerRequestSchema = z.union([ ElicitRequestSchema, ListRootsRequestSchema, GetTaskRequestSchema, - GetTaskPayloadRequestSchema + GetTaskPayloadRequestSchema, + ListTasksRequestSchema ]); export const ServerNotificationSchema = z.union([ @@ -1569,7 +1586,8 @@ export const ServerResultSchema = z.union([ ReadResourceResultSchema, CallToolResultSchema, ListToolsResultSchema, - GetTaskResultSchema + GetTaskResultSchema, + ListTasksResultSchema ]); export class McpError extends Error { @@ -1676,6 +1694,8 @@ export type TaskCreatedNotification = Infer; export type GetTaskResult = Infer; export type GetTaskPayloadRequest = Infer; +export type ListTasksRequest = Infer; +export type ListTasksResult = Infer; /* Pagination */ export type PaginatedRequest = Infer; From 71a956857edf5f125106deb627fb388caa655b1d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 29 Oct 2025 14:10:05 -0700 Subject: [PATCH 12/84] Automatically execute tool calls as tasks --- package-lock.json | 32 ++++++++++++++++++------- package.json | 1 + src/client/index.ts | 8 ++++++- src/shared/protocol.test.ts | 48 ++++++++++++++++++++++++++----------- src/shared/protocol.ts | 37 +++++++++++++++++++++++++--- src/shared/request.ts | 5 ++-- 6 files changed, 102 insertions(+), 29 deletions(-) diff --git a/package-lock.json b/package-lock.json index 0f614d70e..8ee31c5d8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,6 +9,7 @@ "version": "1.20.2", "license": "MIT", "dependencies": { + "@lukeed/uuid": "^2.0.1", "ajv": "^8.17.1", "ajv-formats": "^3.0.1", "content-type": "^1.0.5", @@ -52,19 +53,11 @@ "node": ">=18" }, "peerDependencies": { - "@cfworker/json-schema": "^4.1.1", - "ajv": "^8.17.1", - "ajv-formats": "^3.0.1" + "@cfworker/json-schema": "^4.1.1" }, "peerDependenciesMeta": { "@cfworker/json-schema": { "optional": true - }, - "ajv": { - "optional": true - }, - "ajv-formats": { - "optional": true } } }, @@ -1610,6 +1603,27 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@lukeed/csprng": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@lukeed/csprng/-/csprng-1.1.0.tgz", + "integrity": "sha512-Z7C/xXCiGWsg0KuKsHTKJxbWhpI3Vs5GwLfOean7MGyVFGqdRgBbAjOCh6u4bbjPc/8MJ2pZmK/0DLdCbivLDA==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/@lukeed/uuid": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@lukeed/uuid/-/uuid-2.0.1.tgz", + "integrity": "sha512-qC72D4+CDdjGqJvkFMMEAtancHUQ7/d/tAiHf64z8MopFDmcrtbcJuerDtFceuAfQJ2pDSfCKCtbqoGBNnwg0w==", + "license": "MIT", + "dependencies": { + "@lukeed/csprng": "^1.1.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/@noble/hashes": { "version": "1.8.0", "resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.8.0.tgz", diff --git a/package.json b/package.json index 22a5b41cc..d7b601af3 100644 --- a/package.json +++ b/package.json @@ -75,6 +75,7 @@ "client": "tsx src/cli.ts client" }, "dependencies": { + "@lukeed/uuid": "^2.0.1", "ajv": "^8.17.1", "ajv-formats": "^3.0.1", "content-type": "^1.0.5", diff --git a/src/client/index.ts b/src/client/index.ts index a4c70c581..da66b1102 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,6 +1,7 @@ import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; import { PendingRequest } from '../shared/request.js'; +import { v4 as uuidv4 } from '@lukeed/uuid'; import { type CallToolRequest, CallToolResultSchema, @@ -368,7 +369,12 @@ export class Client< resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, options?: RequestOptions ): PendingRequest { - return this.beginRequest({ method: 'tools/call', params }, resultSchema, options); + // Automatically add task metadata if not provided + const optionsWithTask = { + ...options, + task: options?.task ?? { taskId: uuidv4() } + }; + return this.beginRequest({ method: 'tools/call', params }, resultSchema, optionsWithTask); } /** diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index a84e5a0ec..b3f1d4e5e 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1,5 +1,15 @@ import { ZodType, z } from 'zod'; -import { ClientCapabilities, ErrorCode, McpError, Notification, Request, Result, ServerCapabilities } from '../types.js'; +import { + ClientCapabilities, + ErrorCode, + McpError, + Notification, + RELATED_TASK_META_KEY, + Request, + Result, + ServerCapabilities, + TASK_META_KEY +} from '../types.js'; import { Protocol, mergeCapabilities } from './protocol.js'; import { Transport } from './transport.js'; @@ -784,7 +794,7 @@ describe('Task-based execution', () => { params: { name: 'test-tool', _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'my-task-123', keepAlive: 30000 } @@ -824,7 +834,7 @@ describe('Task-based execution', () => { name: 'test-tool', _meta: { customField: 'customValue', - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'my-task-456' } } @@ -879,7 +889,7 @@ describe('Task-based execution', () => { params: { data: 'test', _meta: { - 'modelcontextprotocol.io/related-task': { + [RELATED_TASK_META_KEY]: { taskId: 'parent-task-123' } } @@ -911,7 +921,7 @@ describe('Task-based execution', () => { level: 'info', data: 'test message', _meta: { - 'modelcontextprotocol.io/related-task': { + [RELATED_TASK_META_KEY]: { taskId: 'parent-task-456' } } @@ -950,10 +960,10 @@ describe('Task-based execution', () => { params: { name: 'test-tool', _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'my-task-combined' }, - 'modelcontextprotocol.io/related-task': { + [RELATED_TASK_META_KEY]: { taskId: 'parent-task' }, progressToken: expect.any(Number) @@ -994,7 +1004,7 @@ describe('Task-based execution', () => { method: 'test/method', params: { _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'test-task', keepAlive: 60000 } @@ -1066,7 +1076,7 @@ describe('Task-based execution', () => { method: 'test/method', params: { _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'test-task', keepAlive: 60000 } @@ -1122,7 +1132,7 @@ describe('Task-based execution', () => { method: 'test/method', params: { _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'test-task', keepAlive: 60000 } @@ -1201,7 +1211,7 @@ describe('Task-based execution', () => { method: 'test/method', params: { _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'test-task', keepAlive: 60000 } @@ -1398,7 +1408,11 @@ describe('Task-based execution', () => { result: { tasks: [{ taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }], nextCursor: undefined, - _meta: {} + _meta: { + [TASK_META_KEY]: expect.objectContaining({ + taskId: expect.any(String) + }) + } } }); }, 10); @@ -1429,7 +1443,11 @@ describe('Task-based execution', () => { result: { tasks: [{ taskId: 'task-11', status: 'working', keepAlive: 30000, pollFrequency: 1000 }], nextCursor: 'task-11', - _meta: {} + _meta: { + [TASK_META_KEY]: expect.objectContaining({ + taskId: expect.any(String) + }) + } } }); }, 10); @@ -1439,7 +1457,9 @@ describe('Task-based execution', () => { expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ method: 'tasks/list', - params: { cursor: 'task-10' } + params: { + cursor: 'task-10' + } }), expect.any(Object) ); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a25107613..40f346edc 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -223,7 +223,7 @@ export abstract class Protocol = new Map(); private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); - private _pendingTaskCreations: Map void; reject: (reason: Error) => void }> = new Map(); + private _pendingTaskCreations: Map void; reject: (reason: unknown) => void }> = new Map(); private _requestIdToTaskId: Map = new Map(); private _taskStore?: TaskStore; @@ -400,6 +400,16 @@ export abstract class Protocol>((resolve, reject) => { + const earlyReject = (error: unknown) => { + // Clean up task tracking if we reject before sending + if (taskId) { + const resolver = this._pendingTaskCreations.get(taskId); + resolver?.reject(error); + this._pendingTaskCreations.delete(taskId); + } + reject(error); + }; + if (!this._transport) { - reject(new Error('Not connected')); + earlyReject(new Error('Not connected')); return; } if (this._options?.enforceStrictCapabilities === true) { - this.assertCapabilityForMethod(request.method); + try { + this.assertCapabilityForMethod(request.method); + } catch (e) { + earlyReject(e); + return; + } } options?.signal?.throwIfAborted(); @@ -782,6 +812,7 @@ export abstract class Protocol `${e}`).join(', ')}`); + // Both failed - prefer to throw the result error since it's usually more meaningful + // (e.g., timeout, connection error, etc.) than the task creation failure + throw result.reason; }); } From 2167b437cc6dcf62e7dc644438af446a7c43158d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 31 Oct 2025 18:13:03 -0700 Subject: [PATCH 13/84] Implement task API tests on both the client and server --- src/client/index.test.ts | 688 +++++++++++++++++++++++++++++++++++++++ src/server/index.test.ts | 418 ++++++++++++++++++++++++ 2 files changed, 1106 insertions(+) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index de37b2d90..a135a7c14 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -13,6 +13,7 @@ import { ListResourcesRequestSchema, ListToolsRequestSchema, CallToolRequestSchema, + CallToolResultSchema, CreateMessageRequestSchema, ElicitRequestSchema, ListRootsRequestSchema, @@ -21,6 +22,7 @@ import { import { Transport } from '../shared/transport.js'; import { Server } from '../server/index.js'; import { InMemoryTransport } from '../inMemory.js'; +import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; /*** * Test: Initialize with Matching Protocol Version @@ -1239,3 +1241,689 @@ describe('outputSchema validation', () => { ); }); }); + +describe('Task-based execution', () => { + describe('Client calling server', () => { + let serverTaskStore: InMemoryTaskStore; + + beforeEach(() => { + serverTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + serverTaskStore?.cleanup(); + }); + + test('should create task on server via tool call', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Tool executed successfully!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Client creates task on server via tool call + const taskId = 'test-task-create'; + const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { + taskId, + keepAlive: 60000 + } + }); + + await pendingRequest.result(); + + // Verify task was created successfully + const task = await client.getTask({ taskId }); + expect(task.status).toBe('completed'); + }); + + test('should query task status from server using getTask', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Success!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task + const taskId = 'test-task-get'; + const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }); + await pending.result(); + + // Query task status + const task = await client.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('completed'); + }); + + test('should query task result from server using getTaskResult', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Result data!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task + const taskId = 'test-task-result'; + const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }); + await pending.result(); + + // Query task result + const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + expect(result.content).toEqual([{ type: 'text', text: 'Result data!' }]); + }); + + test('should query task list from server using listTasks', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Success!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks + const taskIds = ['task-list-1', 'task-list-2']; + + for (const taskId of taskIds) { + const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }); + await pending.result(); + } + + // Query task list + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); + for (const taskId of taskIds) { + expect(taskList.tasks).toContainEqual( + expect.objectContaining({ + taskId, + status: 'completed' + }) + ); + } + }); + }); + + describe('Server calling client', () => { + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + clientTaskStore?.cleanup(); + }); + + test('should create task on client via server elicitation', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { + username: 'test-user' + } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server creates task on client via elicitation + const taskId = 'elicit-task-create'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pendingRequest = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' } + }, + required: ['username'] + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + + await pendingRequest.result(); + + // Verify task was created + const task = await server.getTask({ taskId }); + expect(task.status).toBe('completed'); + }); + + test('should query task status from client using getTask', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'test-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task on client + const taskId = 'elicit-task-get'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + + // Query task status + const task = await server.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('completed'); + }); + + test('should query task result from client using getTaskResult', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'result-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task on client + const taskId = 'elicit-task-result'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + + // Query task result + const result = await server.getTaskResult({ taskId }, ElicitResultSchema); + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ username: 'result-user' }); + }); + + test('should query task list from client using listTasks', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'list-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks on client + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const taskIds = ['elicit-list-1', 'elicit-list-2']; + for (const taskId of taskIds) { + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + } + + // Query task list + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); + for (const taskId of taskIds) { + expect(taskList.tasks).toContainEqual( + expect.objectContaining({ + taskId, + status: 'completed' + }) + ); + } + }); + }); + + test('should list tasks from server with pagination', async () => { + const serverTaskStore = new InMemoryTaskStore(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: `Result for ${request.params.arguments?.id || 'unknown'}` }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks + const taskIds = ['task-1', 'task-2', 'task-3']; + + for (const taskId of taskIds) { + const pending = client.beginCallTool({ name: 'test-tool', arguments: { id: taskId } }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }); + await pending.result(); + } + + // List all tasks without cursor + const firstPage = await client.listTasks(); + expect(firstPage.tasks.length).toBeGreaterThan(0); + expect(firstPage.tasks.map(t => t.taskId)).toEqual(expect.arrayContaining(taskIds)); + + // If there's a cursor, test pagination + if (firstPage.nextCursor) { + const secondPage = await client.listTasks({ cursor: firstPage.nextCursor }); + expect(secondPage.tasks).toBeDefined(); + } + + serverTaskStore.cleanup(); + }); + + describe('Error scenarios', () => { + let serverTaskStore: InMemoryTaskStore; + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + serverTaskStore = new InMemoryTaskStore(); + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + serverTaskStore?.cleanup(); + clientTaskStore?.cleanup(); + }); + + test('should throw error when querying non-existent task from server', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to get a task that doesn't exist + await expect(client.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + + test('should throw error when querying result of non-existent task from server', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to get result of a task that doesn't exist + await expect(client.getTaskResult({ taskId: 'non-existent-task' }, CallToolResultSchema)).rejects.toThrow(); + }); + + test('should throw error when server queries non-existent task from client', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'test' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to query a task that doesn't exist on client + await expect(server.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + }); +}); diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 988085199..0bf13eaf7 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -1211,4 +1211,422 @@ describe('Task-based execution', () => { // Cleanup taskStore.cleanup(); }); + + describe('Server calling client via elicitation', () => { + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + clientTaskStore?.cleanup(); + }); + + test('should create task on client via elicitation', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { + username: 'server-test-user', + confirmed: true + } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server creates task on client via elicitation + const taskId = 'server-elicit-create'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pendingRequest = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' }, + confirmed: { type: 'boolean' } + }, + required: ['username'] + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + + await pendingRequest.result(); + + // Verify task was created + const task = await server.getTask({ taskId }); + expect(task.status).toBe('completed'); + }); + + test('should query task from client using getTask', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'get-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create task + const taskId = 'server-elicit-get'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + + // Query task + const task = await server.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('completed'); + }); + + test('should query task result from client using getTaskResult', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'result-user', confirmed: true } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create task + const taskId = 'server-elicit-result'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Provide info', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' }, + confirmed: { type: 'boolean' } + } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + + // Query result + const result = await server.getTaskResult({ taskId }, ElicitResultSchema); + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ username: 'result-user', confirmed: true }); + }); + + test('should query task list from client using listTasks', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'list-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const taskIds = ['server-elicit-list-1', 'server-elicit-list-2']; + for (const taskId of taskIds) { + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + } + + // Query task list + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); + for (const taskId of taskIds) { + expect(taskList.tasks).toContainEqual( + expect.objectContaining({ + taskId, + status: 'completed' + }) + ); + } + }); + }); + + test('should handle multiple concurrent task-based tool calls', async () => { + const taskStore = new InMemoryTaskStore(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore + } + ); + + // Set up a tool handler with variable delay + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'async-tool') { + const delay = (request.params.arguments?.delay as number) || 10; + await new Promise(resolve => setTimeout(resolve, delay)); + return { + content: [{ type: 'text', text: `Completed task ${request.params.arguments?.taskNum || 'unknown'}` }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'async-tool', + description: 'An async test tool', + inputSchema: { + type: 'object', + properties: { + delay: { type: 'number' }, + taskNum: { type: 'number' } + } + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks concurrently + const taskIds = ['concurrent-1', 'concurrent-2', 'concurrent-3', 'concurrent-4']; + const pendingRequests = taskIds.map((taskId, index) => + client.beginCallTool({ name: 'async-tool', arguments: { delay: 10 + index * 5, taskNum: index + 1 } }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }) + ); + + // Wait for all tasks to complete + await Promise.all(pendingRequests.map(p => p.result())); + + // Verify all tasks completed successfully + for (let i = 0; i < taskIds.length; i++) { + const task = await client.getTask({ taskId: taskIds[i] }); + expect(task.status).toBe('completed'); + expect(task.taskId).toBe(taskIds[i]); + + const result = await client.getTaskResult({ taskId: taskIds[i] }, CallToolResultSchema); + expect(result.content).toEqual([{ type: 'text', text: `Completed task ${i + 1}` }]); + } + + // Verify listTasks returns all tasks + const taskList = await client.listTasks(); + for (const taskId of taskIds) { + expect(taskList.tasks).toContainEqual(expect.objectContaining({ taskId })); + } + + // Cleanup + taskStore.cleanup(); + }); + + describe('Error scenarios', () => { + let taskStore: InMemoryTaskStore; + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + taskStore = new InMemoryTaskStore(); + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + taskStore?.cleanup(); + clientTaskStore?.cleanup(); + }); + + test('should throw error when client queries non-existent task from server', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to query a task that doesn't exist + await expect(client.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + + test('should throw error when server queries non-existent task from client', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'test' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to query a task that doesn't exist on client + await expect(server.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + }); }); From 12d0f66ebd87b346a07251b9d6f9db7f2d0ff041 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 31 Oct 2025 19:38:40 -0700 Subject: [PATCH 14/84] Make default task polling interval configurable --- src/shared/protocol.ts | 7 ++++++- src/shared/request.ts | 9 ++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 40f346edc..389abfe8f 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -72,6 +72,11 @@ export type ProtocolOptions = { * handle task creation, status tracking, and result storage. */ taskStore?: TaskStore; + /** + * Default polling interval (in milliseconds) for task status checks when no pollFrequency + * is provided by the server. Defaults to 5000ms if not specified. + */ + defaultTaskPollInterval?: number; }; /** @@ -863,7 +868,7 @@ export abstract class Protocol Promise.resolve(); @@ -18,7 +18,8 @@ export class PendingRequest, readonly resultHandle: Promise, readonly resultSchema: ZodType, - readonly taskId?: string + readonly taskId?: string, + readonly defaultTaskPollInterval?: number ) {} /** @@ -66,7 +67,9 @@ export class PendingRequest setTimeout(resolve, task.pollFrequency ?? DEFAULT_POLLING_INTERNAL)); + await new Promise(resolve => + setTimeout(resolve, task.pollFrequency ?? this.defaultTaskPollInterval ?? DEFAULT_TASK_POLLING_INTERVAL) + ); } while (!isTerminal(task.status)); // Process result From bb28ef79808670f5263ac5223e60db51d2a16c27 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 3 Nov 2025 10:52:37 -0800 Subject: [PATCH 15/84] Exclude relatedTask from RequestHandlerExtra --- src/shared/protocol.ts | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 389abfe8f..d3687311d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -145,6 +145,12 @@ export type NotificationOptions = { relatedTask?: RelatedTaskMetadata; }; +/** + * Options that can be given per request. + */ +// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. +export type TaskRequestOptions = Omit; + /** * Extra data given to request handlers. */ @@ -196,7 +202,11 @@ export type RequestHandlerExtra< * * This is used by certain transports to correctly associate related messages. */ - sendRequest: >(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; + sendRequest: >( + request: SendRequestT, + resultSchema: U, + options?: TaskRequestOptions + ) => Promise>; }; /** From 0bf2b429d27e1a7981a38462dba811a2b0f1ccc9 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 3 Nov 2025 11:10:39 -0800 Subject: [PATCH 16/84] Mark tasks as cancelled only after confirming abort --- src/shared/protocol.test.ts | 290 +++++++++++++++++++++++------------- src/shared/protocol.ts | 77 +++++++--- 2 files changed, 249 insertions(+), 118 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index b3f1d4e5e..76782f3cd 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -6,12 +6,16 @@ import { Notification, RELATED_TASK_META_KEY, Request, + RequestId, Result, ServerCapabilities, - TASK_META_KEY + Task, + TASK_META_KEY, + TaskMetadata } from '../types.js'; import { Protocol, mergeCapabilities } from './protocol.js'; import { Transport } from './transport.js'; +import { TaskStore } from './task.js'; // Mock Transport class class MockTransport implements Transport { @@ -26,6 +30,76 @@ class MockTransport implements Transport { async send(_message: unknown): Promise {} } +function createMockTaskStore(options?: { + onStatus?: (status: Task['status']) => void; + onList?: () => void; +}): TaskStore & { [K in keyof TaskStore]: jest.Mock, Parameters> } { + const tasks: Record = {}; + return { + createTask: jest.fn((taskMetadata: TaskMetadata, _1: RequestId, _2: Request) => { + tasks[taskMetadata.taskId] = { + taskId: taskMetadata.taskId, + status: (taskMetadata.status as Task['status'] | undefined) ?? 'submitted', + keepAlive: taskMetadata.keepAlive ?? null, + pollFrequency: (taskMetadata.pollFrequency as Task['pollFrequency'] | undefined) ?? 1000 + }; + options?.onStatus?.('submitted'); + return Promise.resolve(); + }), + getTask: jest.fn((taskId: string) => { + return Promise.resolve(tasks[taskId] ?? null); + }), + updateTaskStatus: jest.fn((taskId, status, error) => { + const task = tasks[taskId]; + if (task) { + task.status = status; + task.error = error; + options?.onStatus?.(task.status); + } + return Promise.resolve(); + }), + storeTaskResult: jest.fn((taskId: string, result: Result) => { + const task = tasks[taskId]; + if (task) { + task.status = 'completed'; + task.result = result; + options?.onStatus?.('completed'); + } + return Promise.resolve(); + }), + getTaskResult: jest.fn((taskId: string) => { + const task = tasks[taskId]; + if (task?.result) { + return Promise.resolve(task.result); + } + throw new Error('Task result not found'); + }), + listTasks: jest.fn(() => { + const result = { + tasks: Object.values(tasks) + }; + options?.onList?.(); + return Promise.resolve(result); + }) + }; +} + +function createLatch() { + let latch = false; + const waitForLatch = async () => { + while (!latch) { + await new Promise(resolve => setTimeout(resolve, 0)); + } + }; + + return { + releaseLatch: () => { + latch = true; + }, + waitForLatch + }; +} + describe('protocol tests', () => { let protocol: Protocol; let transport: MockTransport; @@ -977,14 +1051,14 @@ describe('Task-based execution', () => { describe('task status transitions', () => { it('should transition from submitted to working when handler starts', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const workingProcessed = createLatch(); + const mockTaskStore = createMockTaskStore({ + onStatus: status => { + if (status === 'working') { + workingProcessed.releaseLatch(); + } + } + }); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1012,24 +1086,17 @@ describe('Task-based execution', () => { } }); - await new Promise(resolve => setTimeout(resolve, 50)); + await workingProcessed.waitForLatch(); expect(mockTaskStore.createTask).toHaveBeenCalledWith({ taskId: 'test-task', keepAlive: 60000 }, 1, { method: 'test/method', params: expect.any(Object) }); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working'); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined); }); it('should transition to input_required during extra.sendRequest', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const mockTaskStore = createMockTaskStore(); const responsiveTransport = new MockTransport(); responsiveTransport.send = jest.fn().mockImplementation(async (message: unknown) => { @@ -1105,14 +1172,14 @@ describe('Task-based execution', () => { }); it('should mark task as completed when storeTaskResult is called', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const completeProcessed = createLatch(); + const mockTaskStore = createMockTaskStore({ + onStatus: status => { + if (status === 'completed') { + completeProcessed.releaseLatch(); + } + } + }); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1140,20 +1207,20 @@ describe('Task-based execution', () => { } }); - await new Promise(resolve => setTimeout(resolve, 50)); + await completeProcessed.waitForLatch(); expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith('test-task', { result: 'success' }); }); it('should mark task as cancelled when notifications/cancelled is received', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue({ taskId: 'test-task', status: 'working', keepAlive: null, pollFrequency: 500 }), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const cancelProcessed = createLatch(); + const mockTaskStore = createMockTaskStore({ + onStatus: status => { + if (status === 'cancelled') { + cancelProcessed.releaseLatch(); + } + } + }); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1163,35 +1230,50 @@ describe('Task-based execution', () => { await protocol.connect(transport); - void protocol.request({ method: 'test/slow', params: {} }, z.object({ result: z.string() }), { - task: { taskId: 'test-task', keepAlive: 60000 } + const requestInProgress = createLatch(); + const cancelSent = createLatch(); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => { + requestInProgress.releaseLatch(); + await cancelSent.waitForLatch(); + return { + result: 'success' + }; }); - await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + [TASK_META_KEY]: { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); transport.onmessage?.({ jsonrpc: '2.0', method: 'notifications/cancelled', params: { - requestId: 0, + requestId: 1, reason: 'User cancelled' } }); - await new Promise(resolve => setTimeout(resolve, 10)); + await requestInProgress.waitForLatch(); + cancelSent.releaseLatch(); + await cancelProcessed.waitForLatch(); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'cancelled'); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'cancelled', undefined); }); it('should mark task as failed when updateTaskStatus to working fails', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockRejectedValueOnce(new Error('Failed to update status')).mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const mockTaskStore = createMockTaskStore(); + mockTaskStore.updateTaskStatus.mockRejectedValueOnce(new Error('Failed to update status')).mockResolvedValue(undefined); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1221,27 +1303,42 @@ describe('Task-based execution', () => { await new Promise(resolve => setTimeout(resolve, 50)); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working'); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined); expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'failed', 'Failed to mark task as working'); }); }); describe('listTasks', () => { it('should handle tasks/list requests and return tasks from TaskStore', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ - tasks: [ - { taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }, - { taskId: 'task-2', status: 'working', keepAlive: 60000, pollFrequency: 1000 } - ], - nextCursor: 'task-2' - }) - }; + const listedTasks = createLatch(); + const mockTaskStore = createMockTaskStore({ + onList: () => listedTasks.releaseLatch() + }); + await mockTaskStore.createTask( + { + taskId: 'task-1', + status: 'completed', + pollFrequency: 500 + }, + 1, + { + method: 'test/method', + params: {} + } + ); + await mockTaskStore.createTask( + { + taskId: 'task-2', + status: 'working', + keepAlive: 60000, + pollFrequency: 1000 + }, + 2, + { + method: 'test/method', + params: {} + } + ); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1254,37 +1351,41 @@ describe('Task-based execution', () => { // Simulate receiving a tasks/list request transport.onmessage?.({ jsonrpc: '2.0', - id: 1, + id: 3, method: 'tasks/list', params: {} }); - await new Promise(resolve => setTimeout(resolve, 10)); + await listedTasks.waitForLatch(); expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); const sentMessage = sendSpy.mock.calls[0][0]; expect(sentMessage.jsonrpc).toBe('2.0'); - expect(sentMessage.id).toBe(1); + expect(sentMessage.id).toBe(3); expect(sentMessage.result.tasks).toEqual([ { taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }, { taskId: 'task-2', status: 'working', keepAlive: 60000, pollFrequency: 1000 } ]); - expect(sentMessage.result.nextCursor).toBe('task-2'); expect(sentMessage.result._meta).toEqual({}); }); it('should handle tasks/list requests with cursor for pagination', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ - tasks: [{ taskId: 'task-3', status: 'submitted', keepAlive: null, pollFrequency: 500 }], - nextCursor: undefined - }) - }; + const listedTasks = createLatch(); + const mockTaskStore = createMockTaskStore({ + onList: () => listedTasks.releaseLatch() + }); + await mockTaskStore.createTask( + { + taskId: 'task-3', + status: 'submitted', + pollFrequency: 500 + }, + 1, + { + method: 'test/method', + params: {} + } + ); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1304,7 +1405,7 @@ describe('Task-based execution', () => { } }); - await new Promise(resolve => setTimeout(resolve, 10)); + await listedTasks.waitForLatch(); expect(mockTaskStore.listTasks).toHaveBeenCalledWith('task-2'); const sentMessage = sendSpy.mock.calls[0][0]; @@ -1316,17 +1417,10 @@ describe('Task-based execution', () => { }); it('should handle tasks/list requests with empty results', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ - tasks: [], - nextCursor: undefined - }) - }; + const listedTasks = createLatch(); + const mockTaskStore = createMockTaskStore({ + onList: () => listedTasks.releaseLatch() + }); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1344,7 +1438,7 @@ describe('Task-based execution', () => { params: {} }); - await new Promise(resolve => setTimeout(resolve, 10)); + await listedTasks.waitForLatch(); expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); const sentMessage = sendSpy.mock.calls[0][0]; @@ -1356,14 +1450,8 @@ describe('Task-based execution', () => { }); it('should return error for invalid cursor', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockRejectedValue(new Error('Invalid cursor: bad-cursor')) - }; + const mockTaskStore = createMockTaskStore(); + mockTaskStore.listTasks.mockRejectedValue(new Error('Invalid cursor: bad-cursor')); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index d3687311d..ed407be56 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -36,12 +36,14 @@ import { TASK_META_KEY, GetTaskResult, TaskMetadata, - RelatedTaskMetadata + RelatedTaskMetadata, + CancelledNotification, + Task } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; import { PendingRequest } from './request.js'; -import { TaskStore } from './task.js'; +import { isTerminal, TaskStore } from './task.js'; /** * Callback for progress notifications. @@ -239,7 +241,7 @@ export abstract class Protocol = new Map(); private _pendingDebouncedNotifications = new Set(); private _pendingTaskCreations: Map void; reject: (reason: unknown) => void }> = new Map(); - private _requestIdToTaskId: Map = new Map(); + private _requestIdToTaskId: Map = new Map(); private _taskStore?: TaskStore; /** @@ -356,16 +358,18 @@ export abstract class Protocol): Promise { + private async _oncancel(notification: CancelledNotification): Promise { // Handle request cancellation const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); controller?.abort(notification.params.reason); + } + private async _postcancel(requestId: RequestId): Promise { // If this request had a task, mark it as cancelled in storage - const taskId = this._requestIdToTaskId.get(Number(notification.params.requestId)); + const taskId = this._requestIdToTaskId.get(requestId); if (taskId && this._taskStore) { try { - await this._taskStore.updateTaskStatus(taskId, 'cancelled'); + await this._setTaskStatus(taskId, 'cancelled'); } catch (error) { this._onerror(new Error(`Failed to cancel task ${taskId}: ${error}`)); } @@ -536,14 +540,16 @@ export abstract class Protocol { const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; - if (taskMetadata && this._taskStore) { - await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'input_required'); + if (taskMetadata) { + // Allow this to throw to the caller (request handler) + await this._setTaskStatus(taskMetadata.taskId, 'input_required'); } try { return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id, relatedTask }); } finally { - if (taskMetadata && this._taskStore) { - await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'working'); + if (taskMetadata) { + // Allow this to throw to the caller (request handler) + await this._setTaskStatus(taskMetadata.taskId, 'working'); } } }, @@ -557,7 +563,7 @@ export abstract class Protocol { // If this request asked for task creation, create the task and send notification if (taskMetadata && this._taskStore) { - const task = await this._taskStore!.getTask(taskMetadata.taskId); + const task = await this._taskStore.getTask(taskMetadata.taskId); if (task) { throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); } @@ -566,6 +572,7 @@ export abstract class Protocol { // If this request had a task, mark it as working - if (taskMetadata && this._taskStore) { + if (taskMetadata) { try { - await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'working'); + await this._setTaskStatus(taskMetadata.taskId, 'working'); } catch { try { - await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'failed', 'Failed to mark task as working'); + await this._setTaskStatus(taskMetadata.taskId, 'failed', 'Failed to mark task as working'); } catch (error) { throw new McpError(ErrorCode.InternalError, `Failed to mark task as working: ${error}`); } @@ -601,6 +608,8 @@ export abstract class Protocol { if (abortController.signal.aborted) { + // Request was cancelled + await this._postcancel(request.id); return; } @@ -620,8 +629,10 @@ export abstract class Protocol { + async error => { if (abortController.signal.aborted) { + // Request was cancelled + await this._postcancel(request.id); return; } @@ -749,7 +760,7 @@ export abstract class Protocol>((resolve, reject) => { @@ -895,12 +906,44 @@ export abstract class Protocol { + private _waitForTaskCreation(taskId: string): Promise { return new Promise((resolve, reject) => { this._pendingTaskCreations.set(taskId, { resolve, reject }); }); } + private async _setTaskStatus( + taskId: string, + status: Status, + errorReason?: ErrorReason + ) { + if (!this._taskStore) { + // No task store configured + return; + } + + try { + // Check the current task status to avoid overwriting terminal states + // as a safeguard for when the TaskStore implementation doesn't try + // to avoid this. + const task = await this._taskStore.getTask(taskId); + if (!task) { + return; + } + + if (isTerminal(task.status)) { + this._onerror( + new Error(`Failed to update status of task "${taskId}" from terminal status "${task.status}" to "${status}"`) + ); + return; + } + + await this._taskStore.updateTaskStatus(taskId, status, errorReason); + } catch (error) { + throw new Error(`Failed to update status of task "${taskId}" to "${status}": ${error}`); + } + } + /** * Gets the current status of a task. */ From 486e8edba532637f2ac50eb68ddcca2b4850ca1d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 3 Nov 2025 15:02:52 -0800 Subject: [PATCH 17/84] Store task result before attempting to respond to client --- src/shared/protocol.ts | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index ed407be56..bf28afde7 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -613,14 +613,10 @@ export abstract class Protocol { if (abortController.signal.aborted) { From 06db60370c42d4ccc52c7531ec54c68ea7180d28 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 3 Nov 2025 15:54:44 -0800 Subject: [PATCH 18/84] Allow task polling before creation notification arrives --- src/shared/request.ts | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/shared/request.ts b/src/shared/request.ts index d6d48467f..ddfd1c1c6 100644 --- a/src/shared/request.ts +++ b/src/shared/request.ts @@ -36,13 +36,21 @@ export class PendingRequest { - // Blocks for a notifications/tasks/created with the provided task ID - await this.taskCreatedHandle; - await onTaskCreated(); - return await this.taskHandler(this.taskId!, { + // Start task handler immediately without waiting for creation notification + const taskPromise = this.taskHandler(this.taskId!, { onTaskCreated, onTaskStatus }); + + // Call onTaskCreated callback when notification arrives, but don't block taskHandler + // The promise is tied to the lifecycle of taskPromise, so it won't leak + this.taskCreatedHandle + .then(() => onTaskCreated()) + .catch(() => { + // Silently ignore if notification never arrives or fails + }); + + return await taskPromise; })(), this.resultHandle ]).then(([task, result]) => { From 723bc7dc964c5efea321158b7f518f86a3f09823 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 5 Nov 2025 13:04:17 -0800 Subject: [PATCH 19/84] Add session ID to TaskStore methods --- src/examples/shared/inMemoryTaskStore.ts | 12 ++--- src/shared/protocol.test.ts | 31 +++++++------ src/shared/protocol.ts | 59 ++++++++++++++---------- src/shared/task.ts | 18 +++++--- 4 files changed, 71 insertions(+), 49 deletions(-) diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index c9f297c86..85b30e781 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -21,7 +21,7 @@ export class InMemoryTaskStore implements TaskStore { private tasks = new Map(); private cleanupTimers = new Map>(); - async createTask(metadata: TaskMetadata, requestId: RequestId, request: Request): Promise { + async createTask(metadata: TaskMetadata, requestId: RequestId, request: Request, _sessionId?: string): Promise { const taskId = metadata.taskId; if (this.tasks.has(taskId)) { @@ -52,12 +52,12 @@ export class InMemoryTaskStore implements TaskStore { } } - async getTask(taskId: string): Promise { + async getTask(taskId: string, _sessionId?: string): Promise { const stored = this.tasks.get(taskId); return stored ? { ...stored.task } : null; } - async storeTaskResult(taskId: string, result: Result): Promise { + async storeTaskResult(taskId: string, result: Result, _sessionId?: string): Promise { const stored = this.tasks.get(taskId); if (!stored) { throw new Error(`Task with ID ${taskId} not found`); @@ -82,7 +82,7 @@ export class InMemoryTaskStore implements TaskStore { } } - async getTaskResult(taskId: string): Promise { + async getTaskResult(taskId: string, _sessionId?: string): Promise { const stored = this.tasks.get(taskId); if (!stored) { throw new Error(`Task with ID ${taskId} not found`); @@ -95,7 +95,7 @@ export class InMemoryTaskStore implements TaskStore { return stored.result; } - async updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise { + async updateTaskStatus(taskId: string, status: Task['status'], error?: string, _sessionId?: string): Promise { const stored = this.tasks.get(taskId); if (!stored) { throw new Error(`Task with ID ${taskId} not found`); @@ -122,7 +122,7 @@ export class InMemoryTaskStore implements TaskStore { } } - async listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }> { + async listTasks(cursor?: string, _sessionId?: string): Promise<{ tasks: Task[]; nextCursor?: string }> { const PAGE_SIZE = 10; const allTaskIds = Array.from(this.tasks.keys()); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 76782f3cd..a91930eea 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1088,11 +1088,16 @@ describe('Task-based execution', () => { await workingProcessed.waitForLatch(); - expect(mockTaskStore.createTask).toHaveBeenCalledWith({ taskId: 'test-task', keepAlive: 60000 }, 1, { - method: 'test/method', - params: expect.any(Object) - }); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined); + expect(mockTaskStore.createTask).toHaveBeenCalledWith( + { taskId: 'test-task', keepAlive: 60000 }, + 1, + { + method: 'test/method', + params: expect.any(Object) + }, + undefined + ); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined, undefined); }); it('should transition to input_required during extra.sendRequest', async () => { @@ -1209,7 +1214,7 @@ describe('Task-based execution', () => { await completeProcessed.waitForLatch(); - expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith('test-task', { result: 'success' }); + expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith('test-task', { result: 'success' }, undefined); }); it('should mark task as cancelled when notifications/cancelled is received', async () => { @@ -1268,7 +1273,7 @@ describe('Task-based execution', () => { cancelSent.releaseLatch(); await cancelProcessed.waitForLatch(); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'cancelled', undefined); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'cancelled', undefined, undefined); }); it('should mark task as failed when updateTaskStatus to working fails', async () => { @@ -1303,8 +1308,8 @@ describe('Task-based execution', () => { await new Promise(resolve => setTimeout(resolve, 50)); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'failed', 'Failed to mark task as working'); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined, undefined); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'failed', 'Failed to mark task as working', undefined); }); }); @@ -1358,7 +1363,7 @@ describe('Task-based execution', () => { await listedTasks.waitForLatch(); - expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); + expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined, undefined); const sentMessage = sendSpy.mock.calls[0][0]; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(3); @@ -1407,7 +1412,7 @@ describe('Task-based execution', () => { await listedTasks.waitForLatch(); - expect(mockTaskStore.listTasks).toHaveBeenCalledWith('task-2'); + expect(mockTaskStore.listTasks).toHaveBeenCalledWith('task-2', undefined); const sentMessage = sendSpy.mock.calls[0][0]; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(2); @@ -1440,7 +1445,7 @@ describe('Task-based execution', () => { await listedTasks.waitForLatch(); - expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); + expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined, undefined); const sentMessage = sendSpy.mock.calls[0][0]; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(3); @@ -1473,7 +1478,7 @@ describe('Task-based execution', () => { await new Promise(resolve => setTimeout(resolve, 10)); - expect(mockTaskStore.listTasks).toHaveBeenCalledWith('bad-cursor'); + expect(mockTaskStore.listTasks).toHaveBeenCalledWith('bad-cursor', undefined); const sentMessage = sendSpy.mock.calls[0][0]; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(4); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index bf28afde7..87aca8557 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -300,8 +300,8 @@ export abstract class Protocol { - const task = await this._taskStore!.getTask(request.params.taskId); + this.setRequestHandler(GetTaskRequestSchema, async (request, extra) => { + const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); if (!task) { throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); } @@ -317,8 +317,8 @@ export abstract class Protocol { - const task = await this._taskStore!.getTask(request.params.taskId); + this.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra) => { + const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); if (!task) { throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); } @@ -327,7 +327,7 @@ export abstract class Protocol { + this.setRequestHandler(ListTasksRequestSchema, async (request, extra) => { try { - const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor); + const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, extra.sessionId); // @ts-expect-error SendResultT cannot contain ListTasksResult, but we include it in our derived types everywhere else return { tasks, @@ -364,12 +364,12 @@ export abstract class Protocol { + private async _postcancel(requestId: RequestId, sessionId?: string): Promise { // If this request had a task, mark it as cancelled in storage const taskId = this._requestIdToTaskId.get(requestId); if (taskId && this._taskStore) { try { - await this._setTaskStatus(taskId, 'cancelled'); + await this._setTaskStatus(taskId, 'cancelled', undefined, sessionId); } catch (error) { this._onerror(new Error(`Failed to cancel task ${taskId}: ${error}`)); } @@ -542,14 +542,14 @@ export abstract class Protocol { // If this request asked for task creation, create the task and send notification if (taskMetadata && this._taskStore) { - const task = await this._taskStore.getTask(taskMetadata.taskId); + const task = await this._taskStore.getTask(taskMetadata.taskId, capturedTransport?.sessionId); if (task) { throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); } - await this._taskStore.createTask(taskMetadata, request.id, { - method: request.method, - params: request.params - }); + await this._taskStore.createTask( + taskMetadata, + request.id, + { + method: request.method, + params: request.params + }, + capturedTransport?.sessionId + ); this._requestIdToTaskId.set(request.id, taskMetadata.taskId); // Send task created notification @@ -594,10 +599,15 @@ export abstract class Protocol { if (abortController.signal.aborted) { // Request was cancelled - await this._postcancel(request.id); + await this._postcancel(request.id, capturedTransport?.sessionId); return; } @@ -619,7 +629,7 @@ export abstract class Protocol { if (abortController.signal.aborted) { // Request was cancelled - await this._postcancel(request.id); + await this._postcancel(request.id, capturedTransport?.sessionId); return; } @@ -918,7 +928,8 @@ export abstract class Protocol( taskId: string, status: Status, - errorReason?: ErrorReason + errorReason?: ErrorReason, + sessionId?: string ) { if (!this._taskStore) { // No task store configured @@ -929,7 +940,7 @@ export abstract class Protocol; + createTask(task: TaskMetadata, requestId: RequestId, request: Request, sessionId?: string): Promise; /** * Gets the current status of a task. * * @param taskId - The task identifier + * @param sessionId - Optional session ID for binding the query to a specific session * @returns The task state including status, keepAlive, pollFrequency, and optional error */ - getTask(taskId: string): Promise; + getTask(taskId: string, sessionId?: string): Promise; /** * Stores the result of a completed task. * * @param taskId - The task identifier * @param result - The result to store + * @param sessionId - Optional session ID for binding the operation to a specific session */ - storeTaskResult(taskId: string, result: Result): Promise; + storeTaskResult(taskId: string, result: Result, sessionId?: string): Promise; /** * Retrieves the stored result of a task. * * @param taskId - The task identifier + * @param sessionId - Optional session ID for binding the query to a specific session * @returns The stored result */ - getTaskResult(taskId: string): Promise; + getTaskResult(taskId: string, sessionId?: string): Promise; /** * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). @@ -46,16 +50,18 @@ export interface TaskStore { * @param taskId - The task identifier * @param status - The new status * @param error - Optional error message if status is 'failed' or 'cancelled' + * @param sessionId - Optional session ID for binding the operation to a specific session */ - updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise; + updateTaskStatus(taskId: string, status: Task['status'], error?: string, sessionId?: string): Promise; /** * Lists tasks, optionally starting from a pagination cursor. * * @param cursor - Optional cursor for pagination + * @param sessionId - Optional session ID for binding the query to a specific session * @returns An object containing the tasks array and an optional nextCursor */ - listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; + listTasks(cursor?: string, sessionId?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; } /** From 3789080a01e6560ca4a2da4a927f6975dc653382 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 5 Nov 2025 13:25:29 -0800 Subject: [PATCH 20/84] Implement tasks/delete --- src/examples/shared/inMemoryTaskStore.test.ts | 54 ++++++++ src/examples/shared/inMemoryTaskStore.ts | 17 +++ src/shared/protocol.test.ts | 130 ++++++++++++++++++ src/shared/protocol.ts | 24 ++++ src/shared/task.ts | 9 ++ src/types.ts | 17 +++ 6 files changed, 251 insertions(+) diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index 9c8c7dab0..ccbe731ae 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -444,4 +444,58 @@ describe('InMemoryTaskStore', () => { expect(store.getAllTasks()).toHaveLength(0); }); }); + + describe('deleteTask', () => { + it('should delete an existing task', async () => { + await store.createTask({ taskId: 'task-to-delete' }, 1, { + method: 'tools/call', + params: {} + }); + + expect(await store.getTask('task-to-delete')).toBeDefined(); + + await store.deleteTask('task-to-delete'); + + expect(await store.getTask('task-to-delete')).toBeNull(); + }); + + it('should throw error when deleting non-existent task', async () => { + await expect(store.deleteTask('non-existent')).rejects.toThrow('Task with ID non-existent not found'); + }); + + it('should clear cleanup timer when deleting task with keepAlive', async () => { + jest.useFakeTimers(); + + await store.createTask({ taskId: 'task-with-timer', keepAlive: 1000 }, 1, { + method: 'tools/call', + params: {} + }); + + expect(await store.getTask('task-with-timer')).toBeDefined(); + + await store.deleteTask('task-with-timer'); + + // Fast-forward past keepAlive time + jest.advanceTimersByTime(1001); + + // Task should not exist (it was deleted immediately, not cleaned up by timer) + expect(await store.getTask('task-with-timer')).toBeNull(); + + jest.useRealTimers(); + }); + + it('should delete task with result', async () => { + await store.createTask({ taskId: 'task-with-result' }, 1, { + method: 'tools/call', + params: {} + }); + + const result = { content: [{ type: 'text' as const, text: 'Result' }] }; + await store.storeTaskResult('task-with-result', result); + + await store.deleteTask('task-with-result'); + + expect(await store.getTask('task-with-result')).toBeNull(); + }); + }); }); diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index 85b30e781..7b7979f9a 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -148,6 +148,23 @@ export class InMemoryTaskStore implements TaskStore { return { tasks, nextCursor }; } + async deleteTask(taskId: string, _sessionId?: string): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + // Clear any associated cleanup timer + const existingTimer = this.cleanupTimers.get(taskId); + if (existingTimer) { + clearTimeout(existingTimer); + this.cleanupTimers.delete(taskId); + } + + // Delete the task + this.tasks.delete(taskId); + } + /** * Cleanup all timers (useful for testing or graceful shutdown) */ diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index a91930eea..1ab3b7b7b 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -80,6 +80,13 @@ function createMockTaskStore(options?: { }; options?.onList?.(); return Promise.resolve(result); + }), + deleteTask: jest.fn((taskId: string) => { + if (tasks[taskId]) { + delete tasks[taskId]; + return Promise.resolve(); + } + return Promise.reject(new Error(`Task with ID ${taskId} not found`)); }) }; } @@ -1561,4 +1568,127 @@ describe('Task-based execution', () => { expect(result.nextCursor).toBe('task-11'); }); }); + + describe('deleteTask', () => { + it('should handle tasks/delete requests and delete task from TaskStore', async () => { + const taskDeleted = createLatch(); + const mockTaskStore = createMockTaskStore(); + await mockTaskStore.createTask( + { + taskId: 'task-to-delete' + }, + 1, + { + method: 'test/method', + params: {} + } + ); + + mockTaskStore.deleteTask.mockImplementation(async (taskId: string) => { + if (taskId === 'task-to-delete') { + taskDeleted.releaseLatch(); + return; + } + throw new Error('Task not found'); + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = jest.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 5, + method: 'tasks/delete', + params: { + taskId: 'task-to-delete' + } + }); + + await taskDeleted.waitForLatch(); + + expect(mockTaskStore.deleteTask).toHaveBeenCalledWith('task-to-delete', undefined); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const sentMessage = sendSpy.mock.calls[0][0] as any; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(5); + expect(sentMessage.result._meta).toBeDefined(); + }); + + it('should return error with code -32600 when task does not exist', async () => { + const taskDeleted = createLatch(); + const mockTaskStore = createMockTaskStore(); + + mockTaskStore.deleteTask.mockImplementation(async () => { + taskDeleted.releaseLatch(); + throw new Error('Task with ID non-existent not found'); + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = jest.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 6, + method: 'tasks/delete', + params: { + taskId: 'non-existent' + } + }); + + await taskDeleted.waitForLatch(); + + expect(mockTaskStore.deleteTask).toHaveBeenCalledWith('non-existent', undefined); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const sentMessage = sendSpy.mock.calls[0][0] as any; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(6); + expect(sentMessage.error).toBeDefined(); + expect(sentMessage.error.code).toBe(-32600); // InvalidRequest error code + expect(sentMessage.error.message).toContain('Failed to delete task'); + }); + + it('should call deleteTask method from client side', async () => { + await protocol.connect(transport); + + const deleteTaskPromise = protocol.deleteTask({ taskId: 'task-to-delete' }); + + // Simulate server response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sendSpy.mock.calls[0][0].id, + result: { + _meta: {} + } + }); + }, 0); + + const result = await deleteTaskPromise; + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tasks/delete', + params: { + taskId: 'task-to-delete' + } + }), + expect.any(Object) + ); + expect(result._meta).toBeDefined(); + }); + }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 87aca8557..eab89395e 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -10,6 +10,8 @@ import { GetTaskPayloadRequestSchema, ListTasksRequestSchema, ListTasksResultSchema, + DeleteTaskRequestSchema, + DeleteTaskResultSchema, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, @@ -355,6 +357,20 @@ export abstract class Protocol { + try { + await this._taskStore!.deleteTask(request.params.taskId, extra.sessionId); + return { + _meta: {} + } as SendResultT; + } catch (error) { + throw new McpError( + ErrorCode.InvalidRequest, + `Failed to delete task: ${error instanceof Error ? error.message : String(error)}` + ); + } + }); } } @@ -986,6 +1002,14 @@ export abstract class Protocol> { + // @ts-expect-error SendRequestT cannot directly contain DeleteTaskRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/delete', params }, DeleteTaskResultSchema, options); + } + /** * Emits a notification, which is a one-way message that does not expect a response. */ diff --git a/src/shared/task.ts b/src/shared/task.ts index 6e933820d..5206c9427 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -62,6 +62,15 @@ export interface TaskStore { * @returns An object containing the tasks array and an optional nextCursor */ listTasks(cursor?: string, sessionId?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; + + /** + * Deletes a specific task and its associated data. + * + * @param taskId - The task identifier + * @param sessionId - Optional session ID for binding the operation to a specific session + * @throws Error if the task doesn't exist or cannot be deleted + */ + deleteTask(taskId: string, sessionId?: string): Promise; } /** diff --git a/src/types.ts b/src/types.ts index 5b528f551..9b406b0ae 100644 --- a/src/types.ts +++ b/src/types.ts @@ -573,6 +573,21 @@ export const ListTasksResultSchema = PaginatedResultSchema.extend({ tasks: z.array(TaskSchema) }); +/** + * A request to delete a specific task. + */ +export const DeleteTaskRequestSchema = RequestSchema.extend({ + method: z.literal('tasks/delete'), + params: BaseRequestParamsSchema.extend({ + taskId: z.string() + }) +}); + +/** + * The response to a tasks/delete request. + */ +export const DeleteTaskResultSchema = ResultSchema; + /* Resources */ /** * The contents of a specific resource or sub-resource. @@ -1696,6 +1711,8 @@ export type GetTaskResult = Infer; export type GetTaskPayloadRequest = Infer; export type ListTasksRequest = Infer; export type ListTasksResult = Infer; +export type DeleteTaskRequest = Infer; +export type DeleteTaskResult = Infer; /* Pagination */ export type PaginatedRequest = Infer; From 9ae5f8462bfda24cd12d8f2015f1e0db12e482b0 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 5 Nov 2025 13:29:28 -0800 Subject: [PATCH 21/84] Rename pollFrequency to pollInterval --- src/examples/shared/inMemoryTaskStore.test.ts | 2 +- src/examples/shared/inMemoryTaskStore.ts | 2 +- src/shared/protocol.test.ts | 18 +++++++++--------- src/shared/protocol.ts | 2 +- src/shared/request.ts | 2 +- src/shared/task.ts | 2 +- src/types.ts | 2 +- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index ccbe731ae..f6a47a1ba 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -30,7 +30,7 @@ describe('InMemoryTaskStore', () => { expect(task?.taskId).toBe('task-1'); expect(task?.status).toBe('submitted'); expect(task?.keepAlive).toBe(60000); - expect(task?.pollFrequency).toBe(500); + expect(task?.pollInterval).toBe(500); }); it('should create task without keepAlive', async () => { diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index 7b7979f9a..087102f78 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -32,7 +32,7 @@ export class InMemoryTaskStore implements TaskStore { taskId, status: 'submitted', keepAlive: metadata.keepAlive ?? null, - pollFrequency: 500 + pollInterval: 500 }; this.tasks.set(taskId, { diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 1ab3b7b7b..664a0ae75 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -41,7 +41,7 @@ function createMockTaskStore(options?: { taskId: taskMetadata.taskId, status: (taskMetadata.status as Task['status'] | undefined) ?? 'submitted', keepAlive: taskMetadata.keepAlive ?? null, - pollFrequency: (taskMetadata.pollFrequency as Task['pollFrequency'] | undefined) ?? 1000 + pollInterval: (taskMetadata.pollInterval as Task['pollInterval'] | undefined) ?? 1000 }; options?.onStatus?.('submitted'); return Promise.resolve(); @@ -1330,7 +1330,7 @@ describe('Task-based execution', () => { { taskId: 'task-1', status: 'completed', - pollFrequency: 500 + pollInterval: 500 }, 1, { @@ -1343,7 +1343,7 @@ describe('Task-based execution', () => { taskId: 'task-2', status: 'working', keepAlive: 60000, - pollFrequency: 1000 + pollInterval: 1000 }, 2, { @@ -1375,8 +1375,8 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(3); expect(sentMessage.result.tasks).toEqual([ - { taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }, - { taskId: 'task-2', status: 'working', keepAlive: 60000, pollFrequency: 1000 } + { taskId: 'task-1', status: 'completed', keepAlive: null, pollInterval: 500 }, + { taskId: 'task-2', status: 'working', keepAlive: 60000, pollInterval: 1000 } ]); expect(sentMessage.result._meta).toEqual({}); }); @@ -1390,7 +1390,7 @@ describe('Task-based execution', () => { { taskId: 'task-3', status: 'submitted', - pollFrequency: 500 + pollInterval: 500 }, 1, { @@ -1423,7 +1423,7 @@ describe('Task-based execution', () => { const sentMessage = sendSpy.mock.calls[0][0]; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(2); - expect(sentMessage.result.tasks).toEqual([{ taskId: 'task-3', status: 'submitted', keepAlive: null, pollFrequency: 500 }]); + expect(sentMessage.result.tasks).toEqual([{ taskId: 'task-3', status: 'submitted', keepAlive: null, pollInterval: 500 }]); expect(sentMessage.result.nextCursor).toBeUndefined(); expect(sentMessage.result._meta).toEqual({}); }); @@ -1506,7 +1506,7 @@ describe('Task-based execution', () => { jsonrpc: '2.0', id: sendSpy.mock.calls[0][0].id, result: { - tasks: [{ taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }], + tasks: [{ taskId: 'task-1', status: 'completed', keepAlive: null, pollInterval: 500 }], nextCursor: undefined, _meta: { [TASK_META_KEY]: expect.objectContaining({ @@ -1541,7 +1541,7 @@ describe('Task-based execution', () => { jsonrpc: '2.0', id: sendSpy.mock.calls[0][0].id, result: { - tasks: [{ taskId: 'task-11', status: 'working', keepAlive: 30000, pollFrequency: 1000 }], + tasks: [{ taskId: 'task-11', status: 'working', keepAlive: 30000, pollInterval: 1000 }], nextCursor: 'task-11', _meta: { [TASK_META_KEY]: expect.objectContaining({ diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index eab89395e..86324f58b 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -77,7 +77,7 @@ export type ProtocolOptions = { */ taskStore?: TaskStore; /** - * Default polling interval (in milliseconds) for task status checks when no pollFrequency + * Default polling interval (in milliseconds) for task status checks when no pollInterval * is provided by the server. Defaults to 5000ms if not specified. */ defaultTaskPollInterval?: number; diff --git a/src/shared/request.ts b/src/shared/request.ts index ddfd1c1c6..5f1f70478 100644 --- a/src/shared/request.ts +++ b/src/shared/request.ts @@ -76,7 +76,7 @@ export class PendingRequest - setTimeout(resolve, task.pollFrequency ?? this.defaultTaskPollInterval ?? DEFAULT_TASK_POLLING_INTERVAL) + setTimeout(resolve, task.pollInterval ?? this.defaultTaskPollInterval ?? DEFAULT_TASK_POLLING_INTERVAL) ); } while (!isTerminal(task.status)); diff --git a/src/shared/task.ts b/src/shared/task.ts index 5206c9427..3bb09195d 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -22,7 +22,7 @@ export interface TaskStore { * * @param taskId - The task identifier * @param sessionId - Optional session ID for binding the query to a specific session - * @returns The task state including status, keepAlive, pollFrequency, and optional error + * @returns The task state including status, keepAlive, pollInterval, and optional error */ getTask(taskId: string, sessionId?: string): Promise; diff --git a/src/types.ts b/src/types.ts index 9b406b0ae..fff67145e 100644 --- a/src/types.ts +++ b/src/types.ts @@ -503,7 +503,7 @@ export const TaskSchema = z.object({ taskId: z.string(), status: z.enum(['submitted', 'working', 'input_required', 'completed', 'failed', 'cancelled', 'unknown']), keepAlive: z.union([z.number(), z.null()]), - pollFrequency: z.optional(z.number()), + pollInterval: z.optional(z.number()), error: z.optional(z.string()) }); From 719675a0360d2a983918bb3898f939200c32c224 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 5 Nov 2025 15:32:21 -0800 Subject: [PATCH 22/84] Implement capabilities for tasks --- src/client/index.test.ts | 708 +++++++++++++++--- src/client/index.ts | 89 +++ src/server/index.test.ts | 487 ++++++++++-- src/server/index.ts | 71 ++ .../protocol-transport-handling.test.ts | 1 + src/shared/protocol.test.ts | 20 + src/shared/protocol.ts | 73 +- src/types.ts | 135 +++- 8 files changed, 1424 insertions(+), 160 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index a135a7c14..5049bac05 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -833,10 +833,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -908,10 +926,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -980,10 +1016,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1048,10 +1102,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1143,10 +1215,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1223,10 +1313,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1262,7 +1370,16 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: serverTaskStore } @@ -1290,10 +1407,28 @@ describe('Task-based execution', () => { ] })); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1323,7 +1458,16 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: serverTaskStore } @@ -1351,10 +1495,28 @@ describe('Task-based execution', () => { ] })); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1382,7 +1544,16 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: serverTaskStore } @@ -1410,10 +1581,28 @@ describe('Task-based execution', () => { ] })); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1439,7 +1628,16 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: serverTaskStore } @@ -1467,10 +1665,28 @@ describe('Task-based execution', () => { ] })); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1519,7 +1735,16 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1532,10 +1757,28 @@ describe('Task-based execution', () => { } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1581,7 +1824,16 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1592,10 +1844,28 @@ describe('Task-based execution', () => { content: { username: 'test-user' } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1639,7 +1909,16 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1650,10 +1929,28 @@ describe('Task-based execution', () => { content: { username: 'result-user' } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1696,7 +1993,16 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1707,10 +2013,28 @@ describe('Task-based execution', () => { content: { username: 'list-user' } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1765,7 +2089,16 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: serverTaskStore } @@ -1793,10 +2126,28 @@ describe('Task-based execution', () => { ] })); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1848,16 +2199,43 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: serverTaskStore } ); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1875,16 +2253,43 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: serverTaskStore } ); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1902,7 +2307,16 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1913,10 +2327,28 @@ describe('Task-based execution', () => { content: { username: 'test' } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1927,3 +2359,111 @@ describe('Task-based execution', () => { }); }); }); + +test('should respect server task capabilities', async () => { + const serverTaskStore = new InMemoryTaskStore(); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async () => ({ + content: [{ type: 'text', text: 'Success!' }] + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } + }, + enforceStrictCapabilities: true + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server supports task creation for tools/call and task methods + expect(client.getServerCapabilities()).toEqual({ + tools: {}, + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + }); + + // These should work because server supports tasks + const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { taskId: 'test-task', keepAlive: 60000 } + }); + await expect(pendingRequest.result()).resolves.not.toThrow(); + await expect(client.listTasks()).resolves.not.toThrow(); + await expect(client.getTask({ taskId: 'test-task' })).resolves.not.toThrow(); + + // This should throw because server doesn't support task creation for tools/list + await expect( + client.beginRequest( + { + method: 'tools/list', + params: {} + }, + z.object({ tools: z.array(z.any()) }), + { task: { taskId: 'test-task-2', keepAlive: 60000 } } + ).result() + ).rejects.toThrow('Server does not support task creation for tools/list'); + + serverTaskStore.cleanup(); +}); diff --git a/src/client/index.ts b/src/client/index.ts index da66b1102..857a86e5a 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -294,6 +294,12 @@ export class Client< } protected assertRequestHandlerCapability(method: string): void { + // Task handlers are registered in Protocol constructor before _capabilities is initialized + // Skip capability check for task methods during initialization + if (!this._capabilities) { + return; + } + switch (method) { case 'sampling/createMessage': if (!this._capabilities.sampling) { @@ -313,12 +319,95 @@ export class Client< } break; + case 'tasks/get': + case 'tasks/list': + case 'tasks/result': + case 'tasks/delete': + if (!this._capabilities.tasks) { + throw new Error(`Client does not support tasks capability (required for ${method})`); + } + break; + case 'ping': // No specific capability required for ping break; } } + protected assertTaskCapability(method: string): void { + if (!this._serverCapabilities?.tasks?.requests) { + throw new Error(`Server does not support task creation (required for ${method})`); + } + + const requests = this._serverCapabilities.tasks.requests; + + switch (method) { + case 'tools/call': + if (!requests.tools?.call) { + throw new Error(`Server does not support task creation for tools/call (required for ${method})`); + } + break; + + case 'tools/list': + if (!requests.tools?.list) { + throw new Error(`Server does not support task creation for tools/list (required for ${method})`); + } + break; + + case 'resources/read': + if (!requests.resources?.read) { + throw new Error(`Server does not support task creation for resources/read (required for ${method})`); + } + break; + + case 'resources/list': + if (!requests.resources?.list) { + throw new Error(`Server does not support task creation for resources/list (required for ${method})`); + } + break; + + case 'prompts/get': + if (!requests.prompts?.get) { + throw new Error(`Server does not support task creation for prompts/get (required for ${method})`); + } + break; + + case 'prompts/list': + if (!requests.prompts?.list) { + throw new Error(`Server does not support task creation for prompts/list (required for ${method})`); + } + break; + + case 'tasks/get': + if (!requests.tasks?.get) { + throw new Error(`Server does not support task creation for tasks/get (required for ${method})`); + } + break; + + case 'tasks/list': + if (!requests.tasks?.list) { + throw new Error(`Server does not support task creation for tasks/list (required for ${method})`); + } + break; + + case 'tasks/result': + if (!requests.tasks?.result) { + throw new Error(`Server does not support task creation for tasks/result (required for ${method})`); + } + break; + + case 'tasks/delete': + if (!requests.tasks?.delete) { + throw new Error(`Server does not support task creation for tasks/delete (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } + } + async ping(options?: RequestOptions) { return this.request({ method: 'ping' }, EmptyResultSchema, options); } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 0bf13eaf7..a804b2564 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -967,7 +967,16 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore } @@ -998,10 +1007,28 @@ describe('Task-based execution', () => { ] })); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1068,10 +1095,28 @@ describe('Task-based execution', () => { ] })); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1092,7 +1137,16 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore } @@ -1105,7 +1159,19 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } } } ); @@ -1231,7 +1297,19 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1245,10 +1323,28 @@ describe('Task-based execution', () => { } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1295,7 +1391,19 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1306,10 +1414,28 @@ describe('Task-based execution', () => { content: { username: 'get-user' } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1353,7 +1479,19 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1364,10 +1502,28 @@ describe('Task-based execution', () => { content: { username: 'result-user', confirmed: true } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1413,7 +1569,19 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1424,10 +1592,28 @@ describe('Task-based execution', () => { content: { username: 'list-user' } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1482,7 +1668,16 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore } @@ -1516,10 +1711,28 @@ describe('Task-based execution', () => { ] })); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1578,16 +1791,43 @@ describe('Task-based execution', () => { }, { capabilities: { - tools: {} + tools: {}, + tasks: { + requests: { + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore } ); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1605,7 +1845,19 @@ describe('Task-based execution', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } }, taskStore: clientTaskStore } @@ -1616,10 +1868,28 @@ describe('Task-based execution', () => { content: { username: 'test' } })); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1630,3 +1900,128 @@ describe('Task-based execution', () => { }); }); }); + +test('should respect client task capabilities', async () => { + const clientTaskStore = new InMemoryTaskStore(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + sampling: {}, + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'test-user' } + })); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + }, + enforceStrictCapabilities: true + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Client supports task creation for elicitation/create and task methods + expect(server.getClientCapabilities()).toEqual({ + sampling: {}, + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: true + }, + tasks: { + get: true, + list: true, + result: true + } + } + } + }); + + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + // These should work because client supports tasks + const pendingRequest = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Test', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId: 'test-task', keepAlive: 60000 } } + ); + await expect(pendingRequest.result()).resolves.not.toThrow(); + await expect(server.listTasks()).resolves.not.toThrow(); + await expect(server.getTask({ taskId: 'test-task' })).resolves.not.toThrow(); + + // This should throw because client doesn't support task creation for sampling/createMessage + await expect( + server.beginRequest( + { + method: 'sampling/createMessage', + params: { + messages: [], + maxTokens: 10 + } + }, + z.object({ + model: z.string(), + role: z.string(), + content: z.any() + }), + { task: { taskId: 'test-task-2', keepAlive: 60000 } } + ).result() + ).rejects.toThrow('Client does not support task creation for sampling/createMessage'); + + clientTaskStore.cleanup(); +}); diff --git a/src/server/index.ts b/src/server/index.ts index d63b4a207..ed543bab5 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -236,6 +236,12 @@ export class Server< } protected assertRequestHandlerCapability(method: string): void { + // Task handlers are registered in Protocol constructor before _capabilities is initialized + // Skip capability check for task methods during initialization + if (!this._capabilities) { + return; + } + switch (method) { case 'sampling/createMessage': if (!this._capabilities.sampling) { @@ -271,6 +277,15 @@ export class Server< } break; + case 'tasks/get': + case 'tasks/list': + case 'tasks/result': + case 'tasks/delete': + if (!this._capabilities.tasks) { + throw new Error(`Server does not support tasks capability (required for ${method})`); + } + break; + case 'ping': case 'initialize': // No specific capability required for these methods @@ -278,6 +293,62 @@ export class Server< } } + protected assertTaskCapability(method: string): void { + if (!this._clientCapabilities?.tasks?.requests) { + throw new Error(`Client does not support task creation (required for ${method})`); + } + + const requests = this._clientCapabilities.tasks.requests; + + switch (method) { + case 'sampling/createMessage': + if (!requests.sampling?.createMessage) { + throw new Error(`Client does not support task creation for sampling/createMessage (required for ${method})`); + } + break; + + case 'elicitation/create': + if (!requests.elicitation?.create) { + throw new Error(`Client does not support task creation for elicitation/create (required for ${method})`); + } + break; + + case 'roots/list': + if (!requests.roots?.list) { + throw new Error(`Client does not support task creation for roots/list (required for ${method})`); + } + break; + + case 'tasks/get': + if (!requests.tasks?.get) { + throw new Error(`Client does not support task creation for tasks/get (required for ${method})`); + } + break; + + case 'tasks/list': + if (!requests.tasks?.list) { + throw new Error(`Client does not support task creation for tasks/list (required for ${method})`); + } + break; + + case 'tasks/result': + if (!requests.tasks?.result) { + throw new Error(`Client does not support task creation for tasks/result (required for ${method})`); + } + break; + + case 'tasks/delete': + if (!requests.tasks?.delete) { + throw new Error(`Client does not support task creation for tasks/delete (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } + } + private async _oninitialize(request: InitializeRequest): Promise { const requestedVersion = request.params.protocolVersion; diff --git a/src/shared/protocol-transport-handling.test.ts b/src/shared/protocol-transport-handling.test.ts index 375a0ee78..74ce71212 100644 --- a/src/shared/protocol-transport-handling.test.ts +++ b/src/shared/protocol-transport-handling.test.ts @@ -37,6 +37,7 @@ describe('Protocol transport handling bug', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })(); transportA = new MockTransport('A'); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 664a0ae75..0a3905d03 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -119,6 +119,7 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })(); }); @@ -573,6 +574,7 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); await protocol.connect(transport); @@ -594,6 +596,7 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); await protocol.connect(transport); @@ -613,6 +616,7 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -637,6 +641,7 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -664,6 +669,7 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -689,6 +695,7 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method await protocol.connect(transport); @@ -722,6 +729,7 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -846,6 +854,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })(); }); @@ -1071,6 +1080,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1134,6 +1144,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(responsiveTransport); @@ -1197,6 +1208,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1238,6 +1250,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1291,6 +1304,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1356,6 +1370,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1403,6 +1418,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1438,6 +1454,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1469,6 +1486,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1596,6 +1614,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = jest.spyOn(serverTransport, 'send'); @@ -1634,6 +1653,7 @@ describe('Task-based execution', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} })({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = jest.spyOn(serverTransport, 'send'); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 86324f58b..eef66449e 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -577,38 +577,43 @@ export abstract class Protocol { - // If this request asked for task creation, create the task and send notification - if (taskMetadata && this._taskStore) { - const task = await this._taskStore.getTask(taskMetadata.taskId, capturedTransport?.sessionId); - if (task) { - throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); - } + // If this request asked for task creation, check capability first, then create the task and send notification + if (taskMetadata) { + // Check if the request method supports task creation + this.assertTaskCapability(request.method); - await this._taskStore.createTask( - taskMetadata, - request.id, - { - method: request.method, - params: request.params - }, - capturedTransport?.sessionId - ); - this._requestIdToTaskId.set(request.id, taskMetadata.taskId); + if (this._taskStore) { + const task = await this._taskStore.getTask(taskMetadata.taskId, capturedTransport?.sessionId); + if (task) { + throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); + } - // Send task created notification - await this.notification( - { - method: 'notifications/tasks/created', - params: { - _meta: { - [RELATED_TASK_META_KEY]: { - taskId: taskMetadata.taskId + await this._taskStore.createTask( + taskMetadata, + request.id, + { + method: request.method, + params: request.params + }, + capturedTransport?.sessionId + ); + this._requestIdToTaskId.set(request.id, taskMetadata.taskId); + + // Send task created notification + await this.notification( + { + method: 'notifications/tasks/created', + params: { + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: taskMetadata.taskId + } } } - } - } as SendNotificationT, - { relatedRequestId: request.id } - ); + } as SendNotificationT, + { relatedRequestId: request.id } + ); + } } }) .then(async () => { @@ -774,6 +779,13 @@ export abstract class Protocol Date: Wed, 5 Nov 2025 15:51:55 -0800 Subject: [PATCH 23/84] Add taskHint for tool-level signposting --- src/server/mcp.test.ts | 17 +++++++++++------ src/types.ts | 11 ++++++++++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index f3669fa64..a67230243 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -764,7 +764,7 @@ describe('tool()', () => { 'test', 'A tool with everything', { name: z.string() }, - { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false }, + { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false, taskHint: false }, async ({ name }) => ({ content: [{ type: 'text', text: `Hello, ${name}!` }] }) @@ -778,7 +778,8 @@ describe('tool()', () => { annotations: { title: 'Complete Test Tool', readOnlyHint: true, - openWorldHint: false + openWorldHint: false, + taskHint: false } }, async ({ name }) => ({ @@ -802,7 +803,8 @@ describe('tool()', () => { expect(result.tools[0].annotations).toEqual({ title: 'Complete Test Tool', readOnlyHint: true, - openWorldHint: false + openWorldHint: false, + taskHint: false }); expect(result.tools[1].name).toBe('test (new api)'); expect(result.tools[1].description).toBe('A tool with everything'); @@ -830,7 +832,8 @@ describe('tool()', () => { { title: 'Complete Test Tool with empty params', readOnlyHint: true, - openWorldHint: false + openWorldHint: false, + taskHint: false }, async () => ({ content: [{ type: 'text', text: 'Test response' }] @@ -845,7 +848,8 @@ describe('tool()', () => { annotations: { title: 'Complete Test Tool with empty params', readOnlyHint: true, - openWorldHint: false + openWorldHint: false, + taskHint: false } }, async () => ({ @@ -869,7 +873,8 @@ describe('tool()', () => { expect(result.tools[0].annotations).toEqual({ title: 'Complete Test Tool with empty params', readOnlyHint: true, - openWorldHint: false + openWorldHint: false, + taskHint: false }); expect(result.tools[1].name).toBe('test (new api)'); expect(result.tools[1].description).toBe('A tool with everything but empty params'); diff --git a/src/types.ts b/src/types.ts index 7ed723068..ebee92634 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1176,7 +1176,16 @@ export const ToolAnnotationsSchema = z * * Default: true */ - openWorldHint: z.optional(z.boolean()) + openWorldHint: z.optional(z.boolean()), + + /** + * If true, this tool is expected to support task-augmented execution. + * This allows clients to handle long-running operations through polling + * the task system. + * + * Default: false + */ + taskHint: z.optional(z.boolean()) }) .passthrough(); From 01be32daaab6a03367b0d0cc54075585b64a1b10 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 5 Nov 2025 16:01:09 -0800 Subject: [PATCH 24/84] Only auto-add task ID if server capability is set --- src/client/index.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/client/index.ts b/src/client/index.ts index 857a86e5a..a36c16db1 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -461,7 +461,8 @@ export class Client< // Automatically add task metadata if not provided const optionsWithTask = { ...options, - task: options?.task ?? { taskId: uuidv4() } + // We check the server capabilities in auto-assignment, but assume the caller knows what they're doing if they pass this explicitly + task: options?.task ?? (this._serverCapabilities?.tasks?.requests?.tools?.call ? { taskId: uuidv4() } : undefined) }; return this.beginRequest({ method: 'tools/call', params }, resultSchema, optionsWithTask); } From 0b8ced29f191d098087b0b5a3a107d0b29ebd437 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 5 Nov 2025 20:20:57 -0800 Subject: [PATCH 25/84] Correctly check peer task capabilities on receiving end --- src/client/index.test.ts | 43 ++++++++++ src/client/index.ts | 62 ++++++++++++++ src/examples/server/simpleStreamableHttp.ts | 2 +- src/server/index.test.ts | 12 +++ src/server/index.ts | 80 +++++++++++++++++++ .../protocol-transport-handling.test.ts | 1 + src/shared/protocol.test.ts | 20 +++++ src/shared/protocol.ts | 9 ++- 8 files changed, 227 insertions(+), 2 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 5049bac05..f7d872ff4 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -1373,6 +1373,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, @@ -1461,6 +1465,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, @@ -1547,6 +1555,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, @@ -1631,6 +1643,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, @@ -1738,6 +1754,9 @@ describe('Task-based execution', () => { elicitation: {}, tasks: { requests: { + elicitation: { + create: true + }, tasks: { get: true, list: true, @@ -1827,6 +1846,9 @@ describe('Task-based execution', () => { elicitation: {}, tasks: { requests: { + elicitation: { + create: true + }, tasks: { get: true, list: true, @@ -1912,6 +1934,9 @@ describe('Task-based execution', () => { elicitation: {}, tasks: { requests: { + elicitation: { + create: true + }, tasks: { get: true, list: true, @@ -1996,6 +2021,9 @@ describe('Task-based execution', () => { elicitation: {}, tasks: { requests: { + elicitation: { + create: true + }, tasks: { get: true, list: true, @@ -2092,6 +2120,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, @@ -2202,6 +2234,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, @@ -2256,6 +2292,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, @@ -2310,6 +2350,9 @@ describe('Task-based execution', () => { elicitation: {}, tasks: { requests: { + elicitation: { + create: true + }, tasks: { get: true, list: true, diff --git a/src/client/index.ts b/src/client/index.ts index a36c16db1..ebe0878ba 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -408,6 +408,68 @@ export class Client< } } + protected assertTaskHandlerCapability(method: string): void { + // Task handlers are registered in Protocol constructor before _capabilities is initialized + // Skip capability check for task methods during initialization + if (!this._capabilities) { + return; + } + + if (!this._capabilities.tasks?.requests) { + throw new Error(`Client does not support task creation (required for ${method})`); + } + + const requests = this._capabilities.tasks.requests; + + switch (method) { + case 'sampling/createMessage': + if (!requests.sampling?.createMessage) { + throw new Error(`Client does not support task creation for sampling/createMessage (required for ${method})`); + } + break; + + case 'elicitation/create': + if (!requests.elicitation?.create) { + throw new Error(`Client does not support task creation for elicitation/create (required for ${method})`); + } + break; + + case 'roots/list': + if (!requests.roots?.list) { + throw new Error(`Client does not support task creation for roots/list (required for ${method})`); + } + break; + + case 'tasks/get': + if (!requests.tasks?.get) { + throw new Error(`Client does not support task creation for tasks/get (required for ${method})`); + } + break; + + case 'tasks/list': + if (!requests.tasks?.list) { + throw new Error(`Client does not support task creation for tasks/list (required for ${method})`); + } + break; + + case 'tasks/result': + if (!requests.tasks?.result) { + throw new Error(`Client does not support task creation for tasks/result (required for ${method})`); + } + break; + + case 'tasks/delete': + if (!requests.tasks?.delete) { + throw new Error(`Client does not support task creation for tasks/delete (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } + } + async ping(options?: RequestOptions) { return this.request({ method: 'ping' }, EmptyResultSchema, options); } diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index ec73a4f02..dc36c18f0 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -39,7 +39,7 @@ const getServer = () => { websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk' }, { - capabilities: { logging: {} }, + capabilities: { logging: {}, tasks: { requests: { tools: { call: true } } } }, taskStore // Enable task support } ); diff --git a/src/server/index.test.ts b/src/server/index.test.ts index a804b2564..2801ca0c1 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -970,6 +970,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, @@ -1140,6 +1144,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, @@ -1671,6 +1679,10 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { + tools: { + call: true, + list: true + }, tasks: { get: true, list: true, diff --git a/src/server/index.ts b/src/server/index.ts index ed543bab5..b1b6d61ca 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -349,6 +349,86 @@ export class Server< } } + protected assertTaskHandlerCapability(method: string): void { + // Task handlers are registered in Protocol constructor before _capabilities is initialized + // Skip capability check for task methods during initialization + if (!this._capabilities) { + return; + } + + if (!this._capabilities.tasks?.requests) { + throw new Error(`Server does not support task creation (required for ${method})`); + } + + const requests = this._capabilities.tasks.requests; + + switch (method) { + case 'tools/call': + if (!requests.tools?.call) { + throw new Error(`Server does not support task creation for tools/call (required for ${method})`); + } + break; + + case 'tools/list': + if (!requests.tools?.list) { + throw new Error(`Server does not support task creation for tools/list (required for ${method})`); + } + break; + + case 'resources/read': + if (!requests.resources?.read) { + throw new Error(`Server does not support task creation for resources/read (required for ${method})`); + } + break; + + case 'resources/list': + if (!requests.resources?.list) { + throw new Error(`Server does not support task creation for resources/list (required for ${method})`); + } + break; + + case 'prompts/get': + if (!requests.prompts?.get) { + throw new Error(`Server does not support task creation for prompts/get (required for ${method})`); + } + break; + + case 'prompts/list': + if (!requests.prompts?.list) { + throw new Error(`Server does not support task creation for prompts/list (required for ${method})`); + } + break; + + case 'tasks/get': + if (!requests.tasks?.get) { + throw new Error(`Server does not support task creation for tasks/get (required for ${method})`); + } + break; + + case 'tasks/list': + if (!requests.tasks?.list) { + throw new Error(`Server does not support task creation for tasks/list (required for ${method})`); + } + break; + + case 'tasks/result': + if (!requests.tasks?.result) { + throw new Error(`Server does not support task creation for tasks/result (required for ${method})`); + } + break; + + case 'tasks/delete': + if (!requests.tasks?.delete) { + throw new Error(`Server does not support task creation for tasks/delete (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } + } + private async _oninitialize(request: InitializeRequest): Promise { const requestedVersion = request.params.protocolVersion; diff --git a/src/shared/protocol-transport-handling.test.ts b/src/shared/protocol-transport-handling.test.ts index 74ce71212..d57282d7a 100644 --- a/src/shared/protocol-transport-handling.test.ts +++ b/src/shared/protocol-transport-handling.test.ts @@ -38,6 +38,7 @@ describe('Protocol transport handling bug', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })(); transportA = new MockTransport('A'); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 0a3905d03..1433d25f1 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -120,6 +120,7 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })(); }); @@ -575,6 +576,7 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); await protocol.connect(transport); @@ -597,6 +599,7 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); await protocol.connect(transport); @@ -617,6 +620,7 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -642,6 +646,7 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -670,6 +675,7 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -696,6 +702,7 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method await protocol.connect(transport); @@ -730,6 +737,7 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -855,6 +863,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })(); }); @@ -1081,6 +1090,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1145,6 +1155,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(responsiveTransport); @@ -1209,6 +1220,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1251,6 +1263,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1305,6 +1318,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1371,6 +1385,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1419,6 +1434,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1455,6 +1471,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1487,6 +1504,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1615,6 +1633,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = jest.spyOn(serverTransport, 'send'); @@ -1654,6 +1673,7 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = jest.spyOn(serverTransport, 'send'); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index eef66449e..7b9ee56bd 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -580,7 +580,7 @@ export abstract class Protocol Date: Mon, 10 Nov 2025 15:16:55 -0800 Subject: [PATCH 26/84] Introduce shim task interface for tools --- src/client/index.test.ts | 18 +- src/examples/server/simpleStreamableHttp.ts | 43 +++- src/examples/shared/inMemoryTaskStore.ts | 4 +- src/server/mcp.ts | 67 +++++- src/shared/protocol.test.ts | 6 +- src/shared/protocol.ts | 237 ++++++++++++++------ src/shared/task.ts | 3 +- src/types.ts | 9 +- 8 files changed, 291 insertions(+), 96 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 543f9fe73..c3a5d53a6 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -2495,14 +2495,16 @@ test('should respect server task capabilities', async () => { // This should throw because server doesn't support task creation for tools/list await expect( - client.beginRequest( - { - method: 'tools/list', - params: {} - }, - z.object({ tools: z.array(z.any()) }), - { task: { taskId: 'test-task-2', keepAlive: 60000 } } - ).result() + client + .beginRequest( + { + method: 'tools/list', + params: {} + }, + z.object({ tools: z.array(z.any()) }), + { task: { taskId: 'test-task-2', keepAlive: 60000 } } + ) + .result() ).rejects.toThrow('Server does not support task creation for tools/list'); serverTaskStore.cleanup(); diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index c1bcd5ad3..1b619a101 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -454,7 +454,7 @@ const getServer = () => { ); // Register a long-running tool that demonstrates task execution - server.registerTool( + server.registerToolTask( 'delay', { title: 'Delay', @@ -463,16 +463,37 @@ const getServer = () => { duration: z.number().describe('Duration in milliseconds').default(5000) } }, - async ({ duration }): Promise => { - await new Promise(resolve => setTimeout(resolve, duration)); - return { - content: [ - { - type: 'text', - text: `Completed ${duration}ms delay` - } - ] - }; + { + async createTask({ duration }, { taskId, taskStore, taskRequestedKeepAlive }) { + // Simulate out-of-band work + (async () => { + await new Promise(resolve => setTimeout(resolve, duration)); + await taskStore.storeTaskResult(taskId, { + content: [ + { + type: 'text', + text: `Completed ${duration}ms delay` + } + ] + }); + })(); + return await taskStore.createTask({ + taskId, + keepAlive: taskRequestedKeepAlive + }); + }, + async getTask(_args, { taskId, taskStore }) { + const task = await taskStore.getTask(taskId); + if (!task) { + throw new Error(`Task ${taskId} not found`); + } + + return task; + }, + async getTaskResult(_args, { taskId, taskStore }) { + const result = await taskStore.getTaskResult(taskId); + return result as CallToolResult; + } } ); diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index 087102f78..f6a218041 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -21,7 +21,7 @@ export class InMemoryTaskStore implements TaskStore { private tasks = new Map(); private cleanupTimers = new Map>(); - async createTask(metadata: TaskMetadata, requestId: RequestId, request: Request, _sessionId?: string): Promise { + async createTask(metadata: TaskMetadata, requestId: RequestId, request: Request, _sessionId?: string): Promise { const taskId = metadata.taskId; if (this.tasks.has(taskId)) { @@ -50,6 +50,8 @@ export class InMemoryTaskStore implements TaskStore { this.cleanupTimers.set(taskId, timer); } + + return task; } async getTask(taskId: string, _sessionId?: string): Promise { diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 6593ea854..76dbd602f 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -31,6 +31,10 @@ import { ServerNotification, ToolAnnotations, LoggingMessageNotification, + CreateTaskResult, + GetTaskResult, + Result, + TASK_META_KEY, CompleteRequestPrompt, CompleteRequestResourceTemplate, assertCompleteRequestPrompt, @@ -38,8 +42,9 @@ import { } from '../types.js'; import { Completable, CompletableDef } from './completable.js'; import { UriTemplate, Variables } from '../shared/uriTemplate.js'; -import { RequestHandlerExtra } from '../shared/protocol.js'; +import { RequestHandlerExtra, RequestTaskStore } from '../shared/protocol.js'; import { Transport } from '../shared/transport.js'; +import { isTerminal } from '../shared/task.js'; /** * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. @@ -836,6 +841,51 @@ export class McpServer { ); } + /** + * Registers a task-based tool with a config object and callback. + */ + registerToolTask( + name: string, + config: { + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + _meta?: Record; + }, + handler: ToolTaskHandler + ): RegisteredTool { + // TODO: Attach to individual request handlers and remove this wrapper + const cb: ToolCallback = (async (...args) => { + const [inputArgs, extra] = args; + + const taskStore = extra.taskStore; + if (!taskStore) { + throw new Error('Task store is not available'); + } + + const taskMetadata = extra._meta?.[TASK_META_KEY]; + const taskId = taskMetadata?.taskId; + if (!taskId) { + throw new Error('No task ID provided'); + } + + // Internal polling to allow using this interface before internals are hooked up + const taskExtra = { ...extra, taskId, taskStore }; + let task = await handler.createTask(inputArgs, taskExtra); + do { + await new Promise(resolve => setTimeout(resolve, task.pollInterval ?? 5000)); + task = await handler.getTask(inputArgs, taskExtra); + } while (!isTerminal(task.status)); + + const result: CallToolResult = await handler.getTaskResult(inputArgs, taskExtra); + return result; + }) as ToolCallback; + + return this.registerTool(name, { ...config, annotations: { ...config.annotations, taskHint: true } }, cb); + } + /** * Registers a zero-argument prompt `name`, which will run the given function when the client calls it. * @deprecated Use `registerPrompt` instead. @@ -1044,6 +1094,21 @@ export type ToolCallback ? (args: T, extra: RequestHandlerExtra) => CallToolResult | Promise : (extra: RequestHandlerExtra) => CallToolResult | Promise; +export interface TaskRequestHandlerExtra extends RequestHandlerExtra { + taskId: string; + taskStore: RequestTaskStore; +} + +export type TaskRequestHandler = Args extends ZodRawShape + ? (args: z.objectOutputType, extra: TaskRequestHandlerExtra) => SendResultT | Promise + : (extra: TaskRequestHandlerExtra) => SendResultT | Promise; + +export interface ToolTaskHandler { + createTask: TaskRequestHandler; + getTask: TaskRequestHandler; + getTaskResult: TaskRequestHandler; +} + export type RegisteredTool = { title?: string; description?: string; diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index be84d7027..c1918aced 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -37,14 +37,14 @@ function createMockTaskStore(options?: { const tasks: Record = {}; return { createTask: jest.fn((taskMetadata: TaskMetadata, _1: RequestId, _2: Request) => { - tasks[taskMetadata.taskId] = { + const task = (tasks[taskMetadata.taskId] = { taskId: taskMetadata.taskId, status: (taskMetadata.status as Task['status'] | undefined) ?? 'submitted', keepAlive: taskMetadata.keepAlive ?? null, pollInterval: (taskMetadata.pollInterval as Task['pollInterval'] | undefined) ?? 1000 - }; + }); options?.onStatus?.('submitted'); - return Promise.resolve(); + return Promise.resolve(task); }), getTask: jest.fn((taskId: string) => { return Promise.resolve(tasks[taskId] ?? null); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 8bb0c10a2..aa11773de 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -155,6 +155,72 @@ export type NotificationOptions = { // relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. export type TaskRequestOptions = Omit; +/** + * Request-scoped TaskStore interface. + */ +export interface RequestTaskStore { + /** + * Creates a new task with the given metadata and original request. + * + * @param task - The task creation metadata from the request + * @returns The task state including status, keepAlive, pollInterval, and optional error + */ + createTask(task: TaskMetadata): Promise; + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @returns The task state including status, keepAlive, pollInterval, and optional error + */ + getTask(taskId: string): Promise; + + /** + * Stores the result of a completed task. + * + * @param taskId - The task identifier + * @param result - The result to store + */ + storeTaskResult(taskId: string, result: Result): Promise; + + /** + * Retrieves the stored result of a task. + * + * @param taskId - The task identifier + * @returns The stored result + */ + getTaskResult(taskId: string): Promise; + + /** + * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). + * + * @param taskId - The task identifier + * @param status - The new status + * @param error - Optional error message if status is 'failed' or 'cancelled' + */ + updateTaskStatus( + taskId: string, + status: Status, + error?: ErrorReason + ): Promise; + + /** + * Lists tasks, optionally starting from a pagination cursor. + * + * @param cursor - Optional cursor for pagination + * @returns An object containing the tasks array and an optional nextCursor + */ + listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; + + /** + * Deletes a specific task and its associated data. + * + * @param taskId - The task identifier + * @throws Error if the task doesn't exist or cannot be deleted + */ + deleteTask(taskId: string): Promise; +} + /** * Extra data given to request handlers. */ @@ -189,6 +255,12 @@ export type RequestHandlerExtra< */ requestId: RequestId; + taskId?: string; + + taskStore?: RequestTaskStore; + + taskRequestedKeepAlive?: number; + /** * The original HTTP request. */ @@ -383,9 +455,10 @@ export abstract class Protocol { // If this request had a task, mark it as cancelled in storage const taskId = this._requestIdToTaskId.get(requestId); + const taskStore = this._taskStore ? this.requestTaskStore(undefined, sessionId) : undefined; if (taskId && this._taskStore) { try { - await this._setTaskStatus(taskId, 'cancelled', undefined, sessionId); + await taskStore?.updateTaskStatus(taskId, 'cancelled', undefined); } catch (error) { this._onerror(new Error(`Failed to cancel task ${taskId}: ${error}`)); } @@ -546,6 +619,7 @@ export abstract class Protocol = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, @@ -558,20 +632,23 @@ export abstract class Protocol( - taskId: string, - status: Status, - errorReason?: ErrorReason, - sessionId?: string - ) { - if (!this._taskStore) { - // No task store configured - return; - } - - try { - // Check the current task status to avoid overwriting terminal states - // as a safeguard for when the TaskStore implementation doesn't try - // to avoid this. - const task = await this._taskStore.getTask(taskId, sessionId); - if (!task) { - return; - } - - if (isTerminal(task.status)) { - this._onerror( - new Error(`Failed to update status of task "${taskId}" from terminal status "${task.status}" to "${status}"`) - ); - return; - } - - await this._taskStore.updateTaskStatus(taskId, status, errorReason, sessionId); - } catch (error) { - throw new Error(`Failed to update status of task "${taskId}" to "${status}": ${error}`); - } - } - /** * Gets the current status of a task. */ @@ -1181,6 +1196,90 @@ export abstract class Protocol { + if (!request) { + throw new Error('No request provided'); + } + + const result = await taskStore.createTask( + task, + request.id, + { + method: request.method, + params: request.params + }, + sessionId + ); + + // Send task created notification + await this.notification( + { + method: 'notifications/tasks/created', + params: { + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: task.taskId + } + } + } + } as SendNotificationT, + { relatedRequestId: request.id } + ); + + return result; + }, + getTask: async taskId => { + const task = await taskStore.getTask(taskId, sessionId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + + return task; + }, + storeTaskResult: (taskId, result) => { + return taskStore.storeTaskResult(taskId, result, sessionId); + }, + getTaskResult: taskId => { + return taskStore.getTaskResult(taskId, sessionId); + }, + updateTaskStatus: async (taskId, status, errorReason) => { + try { + // Check the current task status to avoid overwriting terminal states + // as a safeguard for when the TaskStore implementation doesn't try + // to avoid this. + const task = await taskStore.getTask(taskId, sessionId); + if (!task) { + return; + } + + if (isTerminal(task.status)) { + this._onerror( + new Error(`Failed to update status of task "${taskId}" from terminal status "${task.status}" to "${status}"`) + ); + return; + } + + await taskStore.updateTaskStatus(taskId, status, errorReason, sessionId); + } catch (error) { + throw new Error(`Failed to update status of task "${taskId}" to "${status}": ${error}`); + } + }, + listTasks: cursor => { + return taskStore.listTasks(cursor, sessionId); + }, + deleteTask: taskId => { + return taskStore.deleteTask(taskId, sessionId); + } + }; + } } function isPlainObject(value: unknown): value is Record { diff --git a/src/shared/task.ts b/src/shared/task.ts index 3bb09195d..aff01baa6 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -14,8 +14,9 @@ export interface TaskStore { * @param requestId - The JSON-RPC request ID * @param request - The original request that triggered task creation * @param sessionId - Optional session ID for binding the task to a specific session + * @returns The task state including status, keepAlive, pollInterval, and optional error */ - createTask(task: TaskMetadata, requestId: RequestId, request: Request, sessionId?: string): Promise; + createTask(task: TaskMetadata, requestId: RequestId, request: Request, sessionId?: string): Promise; /** * Gets the current status of a task. diff --git a/src/types.ts b/src/types.ts index ad7be97d4..a1e848d90 100644 --- a/src/types.ts +++ b/src/types.ts @@ -664,6 +664,8 @@ export const TaskSchema = z.object({ error: z.optional(z.string()) }); +export const CreateTaskResultSchema = ResultSchema.merge(TaskSchema); + /** * An out-of-band notification used to inform the receiver of a task being created. */ @@ -1732,7 +1734,8 @@ export const ClientResultSchema = z.union([ ElicitResultSchema, ListRootsResultSchema, GetTaskResultSchema, - ListTasksResultSchema + ListTasksResultSchema, + CreateTaskResultSchema ]); /* Server messages */ @@ -1769,7 +1772,8 @@ export const ServerResultSchema = z.union([ CallToolResultSchema, ListToolsResultSchema, GetTaskResultSchema, - ListTasksResultSchema + ListTasksResultSchema, + CreateTaskResultSchema ]); export class McpError extends Error { @@ -1877,6 +1881,7 @@ export type ProgressNotification = Infer; export type Task = Infer; export type TaskMetadata = Infer; export type RelatedTaskMetadata = Infer; +export type CreateTaskResult = Infer; export type TaskCreatedNotification = Infer; export type GetTaskRequest = Infer; export type GetTaskResult = Infer; From b24ea0f41a27eea25996e92f7a32b6c17efe8e8e Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 10 Nov 2025 16:24:28 -0800 Subject: [PATCH 27/84] Remove most capabilities and automatic task management --- src/client/index.test.ts | 443 +++++++++++++----------------------- src/client/index.ts | 84 ------- src/server/index.test.ts | 367 +++++++++++++---------------- src/server/index.ts | 84 ------- src/shared/protocol.test.ts | 157 ++++--------- src/shared/protocol.ts | 39 +--- src/types.ts | 61 +---- 7 files changed, 368 insertions(+), 867 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index c3a5d53a6..9322bd505 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -1310,28 +1310,10 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: true - }, - tasks: { - get: true, - list: true, - result: true - } - } - } - } - } - ); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1371,13 +1353,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true - }, - tasks: { - get: true, - list: true, - result: true + call: true } } } @@ -1386,11 +1362,21 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async request => { + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } if (request.params.name === 'test-tool') { - return { + const result = { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; } throw new Error('Unknown tool'); }); @@ -1408,28 +1394,10 @@ describe('Task-based execution', () => { ] })); - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: true - }, - tasks: { - get: true, - list: true, - result: true - } - } - } - } - } - ); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1463,13 +1431,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true - }, - tasks: { - get: true, - list: true, - result: true + call: true } } } @@ -1478,11 +1440,21 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async request => { + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } if (request.params.name === 'test-tool') { - return { + const result = { content: [{ type: 'text', text: 'Success!' }] }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; } throw new Error('Unknown tool'); }); @@ -1500,28 +1472,10 @@ describe('Task-based execution', () => { ] })); - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: true - }, - tasks: { - get: true, - list: true, - result: true - } - } - } - } - } - ); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1555,11 +1509,6 @@ describe('Task-based execution', () => { tools: { call: true, list: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1568,11 +1517,21 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async request => { + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } if (request.params.name === 'test-tool') { - return { + const result = { content: [{ type: 'text', text: 'Result data!' }] }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; } throw new Error('Unknown tool'); }); @@ -1590,28 +1549,10 @@ describe('Task-based execution', () => { ] })); - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: true - }, - tasks: { - get: true, - list: true, - result: true - } - } - } - } - } - ); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1641,13 +1582,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true - }, - tasks: { - get: true, - list: true, - result: true + call: true } } } @@ -1656,11 +1591,21 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async request => { + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } if (request.params.name === 'test-tool') { - return { + const result = { content: [{ type: 'text', text: 'Success!' }] }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; } throw new Error('Unknown tool'); }); @@ -1678,28 +1623,10 @@ describe('Task-based execution', () => { ] })); - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: true - }, - tasks: { - get: true, - list: true, - result: true - } - } - } - } - } - ); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1753,11 +1680,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1766,12 +1688,22 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async () => ({ - action: 'accept', - content: { - username: 'test-user' + client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); } - })); + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; + }); const server = new Server( { @@ -1784,11 +1716,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1845,11 +1772,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1858,10 +1780,22 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async () => ({ - action: 'accept', - content: { username: 'test-user' } - })); + client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; + }); const server = new Server( { @@ -1874,11 +1808,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1933,11 +1862,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1946,10 +1870,22 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async () => ({ - action: 'accept', - content: { username: 'result-user' } - })); + client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } + const result = { + action: 'accept', + content: { username: 'result-user' } + }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; + }); const server = new Server( { @@ -1962,11 +1898,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2020,11 +1951,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2033,10 +1959,22 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async () => ({ - action: 'accept', - content: { username: 'list-user' } - })); + client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; + }); const server = new Server( { @@ -2049,11 +1987,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2118,13 +2051,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true - }, - tasks: { - get: true, - list: true, - result: true + call: true } } } @@ -2133,11 +2060,21 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async request => { + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } if (request.params.name === 'test-tool') { - return { + const result = { content: [{ type: 'text', text: `Result for ${request.params.arguments?.id || 'unknown'}` }] }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; } throw new Error('Unknown tool'); }); @@ -2166,11 +2103,6 @@ describe('Task-based execution', () => { requests: { tools: { call: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2232,13 +2164,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true - }, - tasks: { - get: true, - list: true, - result: true + call: true } } } @@ -2258,11 +2184,6 @@ describe('Task-based execution', () => { requests: { tools: { call: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2290,13 +2211,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true - }, - tasks: { - get: true, - list: true, - result: true + call: true } } } @@ -2316,11 +2231,6 @@ describe('Task-based execution', () => { requests: { tools: { call: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2349,11 +2259,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2378,11 +2283,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2414,11 +2314,6 @@ test('should respect server task capabilities', async () => { requests: { tools: { call: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2450,17 +2345,6 @@ test('should respect server task capabilities', async () => { version: '1.0.0' }, { - capabilities: { - tasks: { - requests: { - tasks: { - get: true, - list: true, - result: true - } - } - } - }, enforceStrictCapabilities: true } ); @@ -2468,18 +2352,13 @@ test('should respect server task capabilities', async () => { const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Server supports task creation for tools/call and task methods + // Server supports task creation for tools/call expect(client.getServerCapabilities()).toEqual({ tools: {}, tasks: { requests: { tools: { call: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2491,21 +2370,17 @@ test('should respect server task capabilities', async () => { }); await expect(pendingRequest.result()).resolves.not.toThrow(); await expect(client.listTasks()).resolves.not.toThrow(); - await expect(client.getTask({ taskId: 'test-task' })).resolves.not.toThrow(); - // This should throw because server doesn't support task creation for tools/list + // tools/list doesn't support task creation, but it shouldn't throw - it should just ignore the task metadata await expect( - client - .beginRequest( - { - method: 'tools/list', - params: {} - }, - z.object({ tools: z.array(z.any()) }), - { task: { taskId: 'test-task-2', keepAlive: 60000 } } - ) - .result() - ).rejects.toThrow('Server does not support task creation for tools/list'); + client.request( + { + method: 'tools/list', + params: {} + }, + z.object({ tools: z.array(z.any()) }) + ) + ).resolves.not.toThrow(); serverTaskStore.cleanup(); }); diff --git a/src/client/index.ts b/src/client/index.ts index ebe0878ba..e2d25ab7b 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -348,60 +348,6 @@ export class Client< } break; - case 'tools/list': - if (!requests.tools?.list) { - throw new Error(`Server does not support task creation for tools/list (required for ${method})`); - } - break; - - case 'resources/read': - if (!requests.resources?.read) { - throw new Error(`Server does not support task creation for resources/read (required for ${method})`); - } - break; - - case 'resources/list': - if (!requests.resources?.list) { - throw new Error(`Server does not support task creation for resources/list (required for ${method})`); - } - break; - - case 'prompts/get': - if (!requests.prompts?.get) { - throw new Error(`Server does not support task creation for prompts/get (required for ${method})`); - } - break; - - case 'prompts/list': - if (!requests.prompts?.list) { - throw new Error(`Server does not support task creation for prompts/list (required for ${method})`); - } - break; - - case 'tasks/get': - if (!requests.tasks?.get) { - throw new Error(`Server does not support task creation for tasks/get (required for ${method})`); - } - break; - - case 'tasks/list': - if (!requests.tasks?.list) { - throw new Error(`Server does not support task creation for tasks/list (required for ${method})`); - } - break; - - case 'tasks/result': - if (!requests.tasks?.result) { - throw new Error(`Server does not support task creation for tasks/result (required for ${method})`); - } - break; - - case 'tasks/delete': - if (!requests.tasks?.delete) { - throw new Error(`Server does not support task creation for tasks/delete (required for ${method})`); - } - break; - default: // Method doesn't support tasks, which is fine - no error break; @@ -434,36 +380,6 @@ export class Client< } break; - case 'roots/list': - if (!requests.roots?.list) { - throw new Error(`Client does not support task creation for roots/list (required for ${method})`); - } - break; - - case 'tasks/get': - if (!requests.tasks?.get) { - throw new Error(`Client does not support task creation for tasks/get (required for ${method})`); - } - break; - - case 'tasks/list': - if (!requests.tasks?.list) { - throw new Error(`Client does not support task creation for tasks/list (required for ${method})`); - } - break; - - case 'tasks/result': - if (!requests.tasks?.result) { - throw new Error(`Client does not support task creation for tasks/result (required for ${method})`); - } - break; - - case 'tasks/delete': - if (!requests.tasks?.delete) { - throw new Error(`Client does not support task creation for tasks/delete (required for ${method})`); - } - break; - default: // Method doesn't support tasks, which is fine - no error break; diff --git a/src/server/index.test.ts b/src/server/index.test.ts index e5c46370f..85652d363 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -967,13 +967,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true - }, - tasks: { - get: true, - list: true, - result: true + call: true } } } @@ -983,13 +977,23 @@ describe('Task-based execution', () => { ); // Set up a tool handler that simulates some work - server.setRequestHandler(CallToolRequestSchema, async request => { + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } if (request.params.name === 'test-tool') { // Simulate some async work await new Promise(resolve => setTimeout(resolve, 10)); - return { + const result = { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; } throw new Error('Unknown tool'); }); @@ -1018,11 +1022,6 @@ describe('Task-based execution', () => { requests: { tools: { call: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1106,11 +1105,6 @@ describe('Task-based execution', () => { requests: { tools: { call: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1141,13 +1135,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true - }, - tasks: { - get: true, - list: true, - result: true + call: true } } } @@ -1168,11 +1156,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1184,7 +1167,14 @@ describe('Task-based execution', () => { let capturedElicitRequest: z.infer | null = null; // Set up client elicitation handler - client.setRequestHandler(ElicitRequestSchema, async request => { + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } + // Capture the request to verify metadata later capturedElicitRequest = request; @@ -1198,9 +1188,15 @@ describe('Task-based execution', () => { // Set up server tool that makes a nested elicitation request server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } + if (request.params.name === 'collect-info') { // During tool execution, make a nested request to the client using extra.sendRequest - // This should AUTOMATICALLY attach the related-task metadata const elicitResult = await extra.sendRequest( { method: 'elicitation/create', @@ -1221,7 +1217,7 @@ describe('Task-based execution', () => { }) ); - return { + const result = { content: [ { type: 'text', @@ -1229,6 +1225,10 @@ describe('Task-based execution', () => { } ] }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; } throw new Error('Unknown tool'); }); @@ -1306,11 +1306,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1319,36 +1314,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async () => ({ - action: 'accept', - content: { - username: 'server-test-user', - confirmed: true + client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); } - })); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: true - }, - tasks: { - get: true, - list: true, - result: true - } - } - } - } + const result = { + action: 'accept', + content: { username: 'server-test-user', confirmed: true } + }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); } - ); + return result; + }); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1400,11 +1386,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1413,33 +1394,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async () => ({ - action: 'accept', - content: { username: 'get-user' } - })); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: true - }, - tasks: { - get: true, - list: true, - result: true - } - } - } - } + client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); } - ); + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; + }); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1488,11 +1463,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1501,33 +1471,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async () => ({ - action: 'accept', - content: { username: 'result-user', confirmed: true } - })); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: true - }, - tasks: { - get: true, - list: true, - result: true - } - } - } - } + client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); } - ); + const result = { + action: 'accept', + content: { username: 'result-user', confirmed: true } + }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; + }); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1578,11 +1542,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1591,10 +1550,22 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async () => ({ - action: 'accept', - content: { username: 'list-user' } - })); + client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; + }); const server = new Server( { @@ -1607,11 +1578,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1676,13 +1642,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true - }, - tasks: { - get: true, - list: true, - result: true + call: true } } } @@ -1692,13 +1652,23 @@ describe('Task-based execution', () => { ); // Set up a tool handler with variable delay - server.setRequestHandler(CallToolRequestSchema, async request => { + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } if (request.params.name === 'async-tool') { const delay = (request.params.arguments?.delay as number) || 10; await new Promise(resolve => setTimeout(resolve, delay)); - return { + const result = { content: [{ type: 'text', text: `Completed task ${request.params.arguments?.taskNum || 'unknown'}` }] }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; } throw new Error('Unknown tool'); }); @@ -1730,11 +1700,6 @@ describe('Task-based execution', () => { requests: { tools: { call: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1802,10 +1767,8 @@ describe('Task-based execution', () => { tools: {}, tasks: { requests: { - tasks: { - get: true, - list: true, - result: true + tools: { + call: true } } } @@ -1825,11 +1788,6 @@ describe('Task-based execution', () => { requests: { tools: { call: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1858,11 +1816,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1887,11 +1840,6 @@ describe('Task-based execution', () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1925,11 +1873,6 @@ test('should respect client task capabilities', async () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1938,10 +1881,22 @@ test('should respect client task capabilities', async () => { } ); - client.setRequestHandler(ElicitRequestSchema, async () => ({ - action: 'accept', - content: { username: 'test-user' } - })); + client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { + if (extra.taskId) { + await extra.taskStore?.createTask({ + taskId: extra.taskId, + keepAlive: extra.taskRequestedKeepAlive + }); + } + const result = { + action: 'accept', + content: { username: 'test-user' } + }; + if (extra.taskId) { + await extra.taskStore?.storeTaskResult(extra.taskId, result); + } + return result; + }); const server = new Server( { @@ -1954,11 +1909,6 @@ test('should respect client task capabilities', async () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -1978,11 +1928,6 @@ test('should respect client task capabilities', async () => { requests: { elicitation: { create: true - }, - tasks: { - get: true, - list: true, - result: true } } } @@ -2014,21 +1959,23 @@ test('should respect client task capabilities', async () => { // This should throw because client doesn't support task creation for sampling/createMessage await expect( - server.beginRequest( - { - method: 'sampling/createMessage', - params: { - messages: [], - maxTokens: 10 - } - }, - z.object({ - model: z.string(), - role: z.string(), - content: z.any() - }), - { task: { taskId: 'test-task-2', keepAlive: 60000 } } - ).result() + server + .beginRequest( + { + method: 'sampling/createMessage', + params: { + messages: [], + maxTokens: 10 + } + }, + z.object({ + model: z.string(), + role: z.string(), + content: z.any() + }), + { task: { taskId: 'test-task-2', keepAlive: 60000 } } + ) + .result() ).rejects.toThrow('Client does not support task creation for sampling/createMessage'); clientTaskStore.cleanup(); diff --git a/src/server/index.ts b/src/server/index.ts index 1faf2fcf9..6a0d5c8a9 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -313,36 +313,6 @@ export class Server< } break; - case 'roots/list': - if (!requests.roots?.list) { - throw new Error(`Client does not support task creation for roots/list (required for ${method})`); - } - break; - - case 'tasks/get': - if (!requests.tasks?.get) { - throw new Error(`Client does not support task creation for tasks/get (required for ${method})`); - } - break; - - case 'tasks/list': - if (!requests.tasks?.list) { - throw new Error(`Client does not support task creation for tasks/list (required for ${method})`); - } - break; - - case 'tasks/result': - if (!requests.tasks?.result) { - throw new Error(`Client does not support task creation for tasks/result (required for ${method})`); - } - break; - - case 'tasks/delete': - if (!requests.tasks?.delete) { - throw new Error(`Client does not support task creation for tasks/delete (required for ${method})`); - } - break; - default: // Method doesn't support tasks, which is fine - no error break; @@ -369,60 +339,6 @@ export class Server< } break; - case 'tools/list': - if (!requests.tools?.list) { - throw new Error(`Server does not support task creation for tools/list (required for ${method})`); - } - break; - - case 'resources/read': - if (!requests.resources?.read) { - throw new Error(`Server does not support task creation for resources/read (required for ${method})`); - } - break; - - case 'resources/list': - if (!requests.resources?.list) { - throw new Error(`Server does not support task creation for resources/list (required for ${method})`); - } - break; - - case 'prompts/get': - if (!requests.prompts?.get) { - throw new Error(`Server does not support task creation for prompts/get (required for ${method})`); - } - break; - - case 'prompts/list': - if (!requests.prompts?.list) { - throw new Error(`Server does not support task creation for prompts/list (required for ${method})`); - } - break; - - case 'tasks/get': - if (!requests.tasks?.get) { - throw new Error(`Server does not support task creation for tasks/get (required for ${method})`); - } - break; - - case 'tasks/list': - if (!requests.tasks?.list) { - throw new Error(`Server does not support task creation for tasks/list (required for ${method})`); - } - break; - - case 'tasks/result': - if (!requests.tasks?.result) { - throw new Error(`Server does not support task creation for tasks/result (required for ${method})`); - } - break; - - case 'tasks/delete': - if (!requests.tasks?.delete) { - throw new Error(`Server does not support task creation for tasks/delete (required for ${method})`); - } - break; - default: // Method doesn't support tasks, which is fine - no error break; diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index c1918aced..c4803e443 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1,5 +1,6 @@ import { ZodType, z } from 'zod'; import { + CallToolRequestSchema, ClientCapabilities, ErrorCode, McpError, @@ -1097,15 +1098,29 @@ describe('Task-based execution', () => { await protocol.connect(transport); - protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => ({ - result: 'success' - })); + protocol.setRequestHandler(CallToolRequestSchema, async request => { + await mockTaskStore.createTask( + { + taskId: 'test-task', + keepAlive: 60000 + }, + 1, + request, + undefined + ); + await mockTaskStore.updateTaskStatus('test-task', 'working', undefined, undefined); + return { + result: 'success' + }; + }); transport.onmessage?.({ jsonrpc: '2.0', id: 1, - method: 'test/method', + method: 'tools/call', params: { + name: 'test', + arguments: {}, _meta: { [TASK_META_KEY]: { taskId: 'test-task', @@ -1121,7 +1136,7 @@ describe('Task-based execution', () => { { taskId: 'test-task', keepAlive: 60000 }, 1, { - method: 'test/method', + method: 'tools/call', params: expect.any(Object) }, undefined @@ -1168,7 +1183,16 @@ describe('Task-based execution', () => { return Promise.resolve(); }); - protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async (_request, extra) => { + protocol.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + await mockTaskStore.createTask( + { + taskId: 'test-task', + keepAlive: 60000 + }, + 1, + request + ); + await mockTaskStore.updateTaskStatus('test-task', 'working', undefined, undefined); await extra.sendRequest({ method: 'nested/request', params: {} }, z.object({ nested: z.string() })); return { result: 'success' }; }); @@ -1176,8 +1200,10 @@ describe('Task-based execution', () => { responsiveTransport.onmessage?.({ jsonrpc: '2.0', id: 1, - method: 'test/method', + method: 'tools/call', params: { + name: 'test', + arguments: {}, _meta: { [TASK_META_KEY]: { taskId: 'test-task', @@ -1227,55 +1253,17 @@ describe('Task-based execution', () => { await protocol.connect(transport); - protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => ({ - result: 'success' - })); - - transport.onmessage?.({ - jsonrpc: '2.0', - id: 1, - method: 'test/method', - params: { - _meta: { - [TASK_META_KEY]: { - taskId: 'test-task', - keepAlive: 60000 - } - } - } - }); - - await completeProcessed.waitForLatch(); - - expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith('test-task', { result: 'success' }, undefined); - }); - - it('should mark task as cancelled when notifications/cancelled is received', async () => { - const cancelProcessed = createLatch(); - const mockTaskStore = createMockTaskStore({ - onStatus: status => { - if (status === 'cancelled') { - cancelProcessed.releaseLatch(); - } - } - }); - - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); - - await protocol.connect(transport); - - const requestInProgress = createLatch(); - const cancelSent = createLatch(); - - protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => { - requestInProgress.releaseLatch(); - await cancelSent.waitForLatch(); + protocol.setRequestHandler(CallToolRequestSchema, async request => { + await mockTaskStore.createTask( + { + taskId: 'test-task', + keepAlive: 60000 + }, + 1, + request + ); + await mockTaskStore.updateTaskStatus('test-task', 'working', undefined, undefined); + await mockTaskStore.storeTaskResult('test-task', { result: 'success' }, undefined); return { result: 'success' }; @@ -1284,56 +1272,10 @@ describe('Task-based execution', () => { transport.onmessage?.({ jsonrpc: '2.0', id: 1, - method: 'test/method', - params: { - _meta: { - [TASK_META_KEY]: { - taskId: 'test-task', - keepAlive: 60000 - } - } - } - }); - - transport.onmessage?.({ - jsonrpc: '2.0', - method: 'notifications/cancelled', - params: { - requestId: 1, - reason: 'User cancelled' - } - }); - - await requestInProgress.waitForLatch(); - cancelSent.releaseLatch(); - await cancelProcessed.waitForLatch(); - - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'cancelled', undefined, undefined); - }); - - it('should mark task as failed when updateTaskStatus to working fails', async () => { - const mockTaskStore = createMockTaskStore(); - mockTaskStore.updateTaskStatus.mockRejectedValueOnce(new Error('Failed to update status')).mockResolvedValue(undefined); - - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); - - await protocol.connect(transport); - - protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => ({ - result: 'success' - })); - - transport.onmessage?.({ - jsonrpc: '2.0', - id: 1, - method: 'test/method', + method: 'tools/call', params: { + name: 'test', + arguments: {}, _meta: { [TASK_META_KEY]: { taskId: 'test-task', @@ -1343,10 +1285,9 @@ describe('Task-based execution', () => { } }); - await new Promise(resolve => setTimeout(resolve, 50)); + await completeProcessed.waitForLatch(); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined, undefined); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'failed', 'Failed to mark task as working', undefined); + expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith('test-task', { result: 'success' }, undefined); }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index aa11773de..91a16c5b9 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -653,34 +653,11 @@ export abstract class Protocol { - // If this request asked for task creation, check capability first, then create the task and send notification + .then(() => { + // If this request asked for task creation, check capability first if (taskMetadata) { // Check if the request method supports task creation this.assertTaskHandlerCapability(request.method); - - if (this._taskStore) { - const task = await this._taskStore.getTask(taskMetadata.taskId, capturedTransport?.sessionId); - if (task) { - throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); - } - - this._requestIdToTaskId.set(request.id, taskMetadata.taskId); - } - } - }) - .then(async () => { - // If this request had a task, mark it as working - if (taskMetadata) { - try { - await taskStore?.updateTaskStatus(taskMetadata.taskId, 'working', undefined); - } catch { - try { - await taskStore?.updateTaskStatus(taskMetadata.taskId, 'failed', 'Failed to mark task as working'); - } catch (error) { - throw new McpError(ErrorCode.InternalError, `Failed to mark task as working: ${error}`); - } - } } }) .then(() => handler(request, fullExtra)) @@ -692,18 +669,6 @@ export abstract class Protocol Date: Mon, 10 Nov 2025 16:32:04 -0800 Subject: [PATCH 28/84] Support customizing task pollInterval on server --- src/examples/server/simpleStreamableHttp.ts | 3 ++- src/examples/shared/inMemoryTaskStore.ts | 2 +- src/types.ts | 7 ++++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 1b619a101..2c5066eca 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -479,7 +479,8 @@ const getServer = () => { })(); return await taskStore.createTask({ taskId, - keepAlive: taskRequestedKeepAlive + keepAlive: taskRequestedKeepAlive, + pollInterval: 100 }); }, async getTask(_args, { taskId, taskStore }) { diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index f6a218041..20cbe178b 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -32,7 +32,7 @@ export class InMemoryTaskStore implements TaskStore { taskId, status: 'submitted', keepAlive: metadata.keepAlive ?? null, - pollInterval: 500 + pollInterval: metadata.pollInterval ?? 500 }; this.tasks.set(taskId, { diff --git a/src/types.ts b/src/types.ts index a250408bd..35259c542 100644 --- a/src/types.ts +++ b/src/types.ts @@ -44,7 +44,12 @@ export const TaskMetadataSchema = z /** * Time in milliseconds to ask to keep task results available after completion. Only used with taskId. */ - keepAlive: z.number().optional() + keepAlive: z.number().optional(), + + /** + * Time in milliseconds to wait between task status requests. Only used with taskId. + */ + pollInterval: z.optional(z.number()) }) /** * Passthrough required here because we want to allow any additional fields to be added to the request meta. From 2382df4b2e438337f0d6362b9b9d711a271d4953 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 10 Nov 2025 16:41:10 -0800 Subject: [PATCH 29/84] Remove submitted and unknown statuses --- src/examples/shared/inMemoryTaskStore.test.ts | 8 +++----- src/examples/shared/inMemoryTaskStore.ts | 2 +- src/shared/protocol.test.ts | 7 +++---- src/shared/task.ts | 4 ++-- src/types.ts | 2 +- 5 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index f6a47a1ba..c0540783b 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -13,7 +13,7 @@ describe('InMemoryTaskStore', () => { }); describe('createTask', () => { - it('should create a new task with submitted status', async () => { + it('should create a new task with working status', async () => { const metadata: TaskMetadata = { taskId: 'task-1', keepAlive: 60000 @@ -28,7 +28,7 @@ describe('InMemoryTaskStore', () => { const task = await store.getTask('task-1'); expect(task).toBeDefined(); expect(task?.taskId).toBe('task-1'); - expect(task?.status).toBe('submitted'); + expect(task?.status).toBe('working'); expect(task?.keepAlive).toBe(60000); expect(task?.pollInterval).toBe(500); }); @@ -99,9 +99,7 @@ describe('InMemoryTaskStore', () => { }); }); - it('should update task status from submitted to working', async () => { - await store.updateTaskStatus('status-test', 'working'); - + it('should keep task status as working', async () => { const task = await store.getTask('status-test'); expect(task?.status).toBe('working'); }); diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index 20cbe178b..cb6e49b9e 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -30,7 +30,7 @@ export class InMemoryTaskStore implements TaskStore { const task: Task = { taskId, - status: 'submitted', + status: 'working', keepAlive: metadata.keepAlive ?? null, pollInterval: metadata.pollInterval ?? 500 }; diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index c4803e443..ebf6b9e10 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -40,11 +40,11 @@ function createMockTaskStore(options?: { createTask: jest.fn((taskMetadata: TaskMetadata, _1: RequestId, _2: Request) => { const task = (tasks[taskMetadata.taskId] = { taskId: taskMetadata.taskId, - status: (taskMetadata.status as Task['status'] | undefined) ?? 'submitted', + status: (taskMetadata.status as Task['status'] | undefined) ?? 'working', keepAlive: taskMetadata.keepAlive ?? null, pollInterval: (taskMetadata.pollInterval as Task['pollInterval'] | undefined) ?? 1000 }); - options?.onStatus?.('submitted'); + options?.onStatus?.('working'); return Promise.resolve(task); }), getTask: jest.fn((taskId: string) => { @@ -1362,7 +1362,6 @@ describe('Task-based execution', () => { await mockTaskStore.createTask( { taskId: 'task-3', - status: 'submitted', pollInterval: 500 }, 1, @@ -1398,7 +1397,7 @@ describe('Task-based execution', () => { const sentMessage = sendSpy.mock.calls[0][0]; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(2); - expect(sentMessage.result.tasks).toEqual([{ taskId: 'task-3', status: 'submitted', keepAlive: null, pollInterval: 500 }]); + expect(sentMessage.result.tasks).toEqual([{ taskId: 'task-3', status: 'working', keepAlive: null, pollInterval: 500 }]); expect(sentMessage.result.nextCursor).toBeUndefined(); expect(sentMessage.result._meta).toEqual({}); }); diff --git a/src/shared/task.ts b/src/shared/task.ts index aff01baa6..d076dd851 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -79,8 +79,8 @@ export interface TaskStore { * Terminal states are those where the task has finished and will not change. * * @param status - The task status to check - * @returns True if the status is terminal (completed, failed, cancelled, or unknown) + * @returns True if the status is terminal (completed, failed, or cancelled) */ export function isTerminal(status: Task['status']): boolean { - return status === 'completed' || status === 'failed' || status === 'cancelled' || status === 'unknown'; + return status === 'completed' || status === 'failed' || status === 'cancelled'; } diff --git a/src/types.ts b/src/types.ts index 35259c542..61dbb3b57 100644 --- a/src/types.ts +++ b/src/types.ts @@ -604,7 +604,7 @@ export const PaginatedResultSchema = ResultSchema.extend({ */ export const TaskSchema = z.object({ taskId: z.string(), - status: z.enum(['submitted', 'working', 'input_required', 'completed', 'failed', 'cancelled', 'unknown']), + status: z.enum(['working', 'input_required', 'completed', 'failed', 'cancelled']), keepAlive: z.union([z.number(), z.null()]), pollInterval: z.optional(z.number()), error: z.optional(z.string()) From 388e603e5018aa345971c44a904848f1d1950627 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 11 Nov 2025 15:14:03 -0800 Subject: [PATCH 30/84] Fix spec types test compatibility after upstream merge Added FixSpecServerCapabilities and FixSpecInitializeResult type helpers to handle index signature requirements for ServerCapabilities that differ between the spec types (no index signature) and SDK types (passthrough with index signature). --- src/spec.types.test.ts | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spec.types.test.ts b/src/spec.types.test.ts index 2417e6b1d..a387366d4 100644 --- a/src/spec.types.test.ts +++ b/src/spec.types.test.ts @@ -67,6 +67,11 @@ type FixSpecClientCapabilities = T extends { elicitation?: object } ? Omit & { elicitation?: Record } : T; +// Targeted fix: in spec, ServerCapabilities needs index signature to match SDK's passthrough +type FixSpecServerCapabilities = T & { [x: string]: unknown }; + +type FixSpecInitializeResult = T extends { capabilities: infer C } ? T & { capabilities: FixSpecServerCapabilities } : T; + type FixSpecInitializeRequestParams = T extends { capabilities: infer C } ? Omit & { capabilities: FixSpecClientCapabilities } : T; @@ -531,7 +536,7 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - InitializeResult: (sdk: SDKTypes.InitializeResult, spec: SpecTypes.InitializeResult) => { + InitializeResult: (sdk: SDKTypes.InitializeResult, spec: FixSpecInitializeResult) => { sdk = spec; spec = sdk; }, @@ -539,7 +544,7 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - ServerCapabilities: (sdk: SDKTypes.ServerCapabilities, spec: SpecTypes.ServerCapabilities) => { + ServerCapabilities: (sdk: SDKTypes.ServerCapabilities, spec: FixSpecServerCapabilities) => { sdk = spec; spec = sdk; }, From 8d06d4fb5a347adc9a83d4359d06e0b69805b27d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 17 Nov 2025 12:31:19 -0800 Subject: [PATCH 31/84] chore: commit new lockfile --- package-lock.json | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/package-lock.json b/package-lock.json index a66b59700..3189cd2c8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -815,26 +815,6 @@ "dev": true, "license": "MIT" }, - "node_modules/@mswjs/interceptors": { - "version": "0.40.0", - "resolved": "https://registry.npmjs.org/@mswjs/interceptors/-/interceptors-0.40.0.tgz", - "integrity": "sha512-EFd6cVbHsgLa6wa4RljGj6Wk75qoHxUSyc5asLyyPSyuhIcdS2Q3Phw6ImS1q+CkALthJRShiYfKANcQMuMqsQ==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "@open-draft/deferred-promise": "^2.2.0", - "@open-draft/logger": "^0.3.0", - "@open-draft/until": "^2.0.0", - "is-node-process": "^1.2.0", - "outvariant": "^1.4.3", - "strict-event-emitter": "^0.5.1" - }, - "engines": { - "node": ">=18" - } - }, "node_modules/@lukeed/csprng": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/@lukeed/csprng/-/csprng-1.1.0.tgz", @@ -856,6 +836,26 @@ "node": ">=8" } }, + "node_modules/@mswjs/interceptors": { + "version": "0.40.0", + "resolved": "https://registry.npmjs.org/@mswjs/interceptors/-/interceptors-0.40.0.tgz", + "integrity": "sha512-EFd6cVbHsgLa6wa4RljGj6Wk75qoHxUSyc5asLyyPSyuhIcdS2Q3Phw6ImS1q+CkALthJRShiYfKANcQMuMqsQ==", + "dev": true, + "license": "MIT", + "optional": true, + "peer": true, + "dependencies": { + "@open-draft/deferred-promise": "^2.2.0", + "@open-draft/logger": "^0.3.0", + "@open-draft/until": "^2.0.0", + "is-node-process": "^1.2.0", + "outvariant": "^1.4.3", + "strict-event-emitter": "^0.5.1" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/@noble/hashes": { "version": "1.8.0", "resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.8.0.tgz", From 5975661cfa314556a93a63e8aae2c3d8501a250d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 18 Nov 2025 12:51:48 -0800 Subject: [PATCH 32/84] Introduce CreateTaskResult to request flow in tools --- src/client/index.test.ts | 377 ++++++++++++------ src/client/index.ts | 6 +- src/examples/client/simpleStreamableHttp.ts | 4 +- src/examples/server/simpleStreamableHttp.ts | 28 +- src/examples/shared/inMemoryTaskStore.test.ts | 65 +-- src/examples/shared/inMemoryTaskStore.ts | 29 +- src/server/index.test.ts | 331 +++++++++------ src/server/mcp.test.ts | 12 +- src/server/mcp.ts | 150 ++++--- src/shared/protocol.test.ts | 325 ++++----------- src/shared/protocol.ts | 179 ++------- src/shared/request.ts | 22 +- src/shared/task.ts | 10 +- src/types.ts | 123 ++++-- 14 files changed, 855 insertions(+), 806 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 835f9f788..71d5a36d5 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -840,11 +840,11 @@ describe('outputSchema validation', () => { tasks: { requests: { tools: { - call: true + call: {} }, tasks: { get: true, - list: true, + list: {}, result: true } } @@ -933,11 +933,11 @@ describe('outputSchema validation', () => { tasks: { requests: { tools: { - call: true + call: {} }, tasks: { get: true, - list: true, + list: {}, result: true } } @@ -1023,11 +1023,11 @@ describe('outputSchema validation', () => { tasks: { requests: { tools: { - call: true + call: {} }, tasks: { get: true, - list: true, + list: {}, result: true } } @@ -1109,11 +1109,11 @@ describe('outputSchema validation', () => { tasks: { requests: { tools: { - call: true + call: {} }, tasks: { get: true, - list: true, + list: {}, result: true } } @@ -1222,11 +1222,11 @@ describe('outputSchema validation', () => { tasks: { requests: { tools: { - call: true + call: {} }, tasks: { get: true, - list: true, + list: {}, result: true } } @@ -1353,7 +1353,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1363,18 +1363,27 @@ describe('Task-based execution', () => { ); server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } + if (request.params.name === 'test-tool') { const result = { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; } @@ -1404,18 +1413,18 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Client creates task on server via tool call - const taskId = 'test-task-create'; const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { - taskId, - keepAlive: 60000 + ttl: 60000 } }); await pendingRequest.result(); - // Verify task was created successfully - const task = await client.getTask({ taskId }); + // Verify task was created successfully by listing tasks + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const task = taskList.tasks[0]; expect(task.status).toBe('completed'); }); @@ -1431,7 +1440,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1441,18 +1450,27 @@ describe('Task-based execution', () => { ); server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } + if (request.params.name === 'test-tool') { const result = { content: [{ type: 'text', text: 'Success!' }] }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; } @@ -1482,16 +1500,17 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create a task - const taskId = 'test-task-get'; const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { - task: { taskId, keepAlive: 60000 } + task: { ttl: 60000 } }); await pending.result(); - // Query task status - const task = await client.getTask({ taskId }); + // Query task status by listing tasks and getting the first one + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const task = taskList.tasks[0]; expect(task).toBeDefined(); - expect(task.taskId).toBe(taskId); + expect(task.taskId).toBeDefined(); expect(task.status).toBe('completed'); }); @@ -1507,8 +1526,8 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true, - list: true + call: {}, + list: {} } } } @@ -1518,18 +1537,27 @@ describe('Task-based execution', () => { ); server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } + if (request.params.name === 'test-tool') { const result = { content: [{ type: 'text', text: 'Result data!' }] }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; } @@ -1559,13 +1587,15 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create a task - const taskId = 'test-task-result'; const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { - task: { taskId, keepAlive: 60000 } + task: { ttl: 60000 } }); await pending.result(); - // Query task result + // Get the task ID from the task list and query task result + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; const result = await client.getTaskResult({ taskId }, CallToolResultSchema); expect(result.content).toEqual([{ type: 'text', text: 'Result data!' }]); }); @@ -1582,7 +1612,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1592,18 +1622,26 @@ describe('Task-based execution', () => { ); server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } if (request.params.name === 'test-tool') { const result = { content: [{ type: 'text', text: 'Success!' }] }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; } @@ -1633,19 +1671,26 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create multiple tasks - const taskIds = ['task-list-1', 'task-list-2']; + const createdTaskIds: string[] = []; - for (const taskId of taskIds) { + for (let i = 0; i < 2; i++) { const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { - task: { taskId, keepAlive: 60000 } + task: { ttl: 60000 } }); await pending.result(); + + // Get the task ID from the task list + const taskList = await client.listTasks(); + const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); + if (newTask) { + createdTaskIds.push(newTask.taskId); + } } // Query task list const taskList = await client.listTasks(); expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); - for (const taskId of taskIds) { + for (const taskId of createdTaskIds) { expect(taskList.tasks).toContainEqual( expect.objectContaining({ taskId, @@ -1679,7 +1724,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1688,19 +1733,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } const result = { action: 'accept', content: { username: 'list-user' } }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; }); @@ -1715,7 +1768,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1728,7 +1781,6 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Server creates task on client via elicitation - const taskId = 'elicit-task-create'; const ElicitResultSchema = z.object({ action: z.enum(['accept', 'decline', 'cancel']), content: z.record(z.unknown()).optional() @@ -1749,11 +1801,16 @@ describe('Task-based execution', () => { } }, ElicitResultSchema, - { task: { taskId, keepAlive: 60000 } } + { task: { ttl: 60000 } } ); await pendingRequest.result(); + // Get the task ID from the task list since it's generated automatically + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; + // Verify task was created const task = await server.getTask({ taskId }); expect(task.status).toBe('completed'); @@ -1771,7 +1828,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1780,19 +1837,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } const result = { action: 'accept', content: { username: 'list-user' } }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; }); @@ -1807,7 +1872,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1820,7 +1885,6 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create a task on client - const taskId = 'elicit-task-get'; const ElicitResultSchema = z.object({ action: z.enum(['accept', 'decline', 'cancel']), content: z.record(z.unknown()).optional() @@ -1838,10 +1902,15 @@ describe('Task-based execution', () => { } }, ElicitResultSchema, - { task: { taskId, keepAlive: 60000 } } + { task: { ttl: 60000 } } ); await pending.result(); + // Get the task ID from the task list since it's generated automatically + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; + // Query task status const task = await server.getTask({ taskId }); expect(task).toBeDefined(); @@ -1861,7 +1930,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1870,19 +1939,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } const result = { action: 'accept', content: { username: 'result-user' } }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; }); @@ -1897,7 +1974,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1910,7 +1987,6 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create a task on client - const taskId = 'elicit-task-result'; const ElicitResultSchema = z.object({ action: z.enum(['accept', 'decline', 'cancel']), content: z.record(z.unknown()).optional() @@ -1928,10 +2004,15 @@ describe('Task-based execution', () => { } }, ElicitResultSchema, - { task: { taskId, keepAlive: 60000 } } + { task: { ttl: 60000 } } ); await pending.result(); + // Get the task ID from the task list since it's generated automatically + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; + // Query task result const result = await server.getTaskResult({ taskId }, ElicitResultSchema); expect(result.action).toBe('accept'); @@ -1950,7 +2031,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1959,19 +2040,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } const result = { action: 'accept', content: { username: 'list-user' } }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; }); @@ -1986,7 +2075,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -2004,8 +2093,8 @@ describe('Task-based execution', () => { content: z.record(z.unknown()).optional() }); - const taskIds = ['elicit-list-1', 'elicit-list-2']; - for (const taskId of taskIds) { + const createdTaskIds: string[] = []; + for (let i = 0; i < 2; i++) { const pending = server.beginRequest( { method: 'elicitation/create', @@ -2018,15 +2107,22 @@ describe('Task-based execution', () => { } }, ElicitResultSchema, - { task: { taskId, keepAlive: 60000 } } + { task: { ttl: 60000 } } ); await pending.result(); + + // Get the task ID from the task list + const taskList = await server.listTasks(); + const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); + if (newTask) { + createdTaskIds.push(newTask.taskId); + } } // Query task list const taskList = await server.listTasks(); expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); - for (const taskId of taskIds) { + for (const taskId of createdTaskIds) { expect(taskList.tasks).toContainEqual( expect.objectContaining({ taskId, @@ -2051,7 +2147,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -2061,18 +2157,26 @@ describe('Task-based execution', () => { ); server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } if (request.params.name === 'test-tool') { const result = { content: [{ type: 'text', text: `Result for ${request.params.arguments?.id || 'unknown'}` }] }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; } @@ -2102,7 +2206,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -2115,19 +2219,26 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create multiple tasks - const taskIds = ['task-1', 'task-2', 'task-3']; + const createdTaskIds: string[] = []; - for (const taskId of taskIds) { - const pending = client.beginCallTool({ name: 'test-tool', arguments: { id: taskId } }, CallToolResultSchema, { - task: { taskId, keepAlive: 60000 } + for (let i = 0; i < 3; i++) { + const pending = client.beginCallTool({ name: 'test-tool', arguments: { id: `task-${i + 1}` } }, CallToolResultSchema, { + task: { ttl: 60000 } }); await pending.result(); + + // Get the task ID from the task list + const taskList = await client.listTasks(); + const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); + if (newTask) { + createdTaskIds.push(newTask.taskId); + } } // List all tasks without cursor const firstPage = await client.listTasks(); expect(firstPage.tasks.length).toBeGreaterThan(0); - expect(firstPage.tasks.map(t => t.taskId)).toEqual(expect.arrayContaining(taskIds)); + expect(firstPage.tasks.map(t => t.taskId)).toEqual(expect.arrayContaining(createdTaskIds)); // If there's a cursor, test pagination if (firstPage.nextCursor) { @@ -2164,7 +2275,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -2183,7 +2294,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -2211,7 +2322,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -2230,7 +2341,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -2258,7 +2369,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -2282,7 +2393,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -2313,7 +2424,7 @@ test('should respect server task capabilities', async () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -2358,7 +2469,7 @@ test('should respect server task capabilities', async () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -2366,7 +2477,7 @@ test('should respect server task capabilities', async () => { // These should work because server supports tasks const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { - task: { taskId: 'test-task', keepAlive: 60000 } + task: { ttl: 60000 } }); await expect(pendingRequest.result()).resolves.not.toThrow(); await expect(client.listTasks()).resolves.not.toThrow(); diff --git a/src/client/index.ts b/src/client/index.ts index 63dae7ef0..e5dfe61b8 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,7 +1,7 @@ import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; import { PendingRequest } from '../shared/request.js'; -import { v4 as uuidv4 } from '@lukeed/uuid'; + import { type CallToolRequest, CallToolResultSchema, @@ -538,11 +538,11 @@ export class Client< resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, options?: RequestOptions ): PendingRequest { - // Automatically add task metadata if not provided + // Add task creation parameters if server supports it and not explicitly provided const optionsWithTask = { ...options, // We check the server capabilities in auto-assignment, but assume the caller knows what they're doing if they pass this explicitly - task: options?.task ?? (this._serverCapabilities?.tasks?.requests?.tools?.call ? { taskId: uuidv4() } : undefined) + task: options?.task ?? (this._serverCapabilities?.tasks?.requests?.tools?.call ? {} : undefined) }; return this.beginRequest({ method: 'tools/call', params }, resultSchema, optionsWithTask); } diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 069c21c73..c55e105b2 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -822,7 +822,7 @@ async function callToolTask(name: string, args: Record): Promis { task: { taskId, - keepAlive: 60000 // Keep results for 60 seconds + ttl: 60000 // Keep results for 60 seconds } } ); @@ -836,7 +836,7 @@ async function callToolTask(name: string, args: Record): Promis }, onTaskStatus: task => { if (lastStatus !== task.status) { - console.log(` ${task.status}${task.error ? ` - ${task.error}` : ''}`); + console.log(` ${task.status}${task.statusMessage ? ` - ${task.statusMessage}` : ''}`); } lastStatus = task.status; } diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 2c5066eca..043bda15e 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -39,7 +39,7 @@ const getServer = () => { websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk' }, { - capabilities: { logging: {}, tasks: { requests: { tools: { call: true } } } }, + capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } }, taskStore // Enable task support } ); @@ -464,7 +464,21 @@ const getServer = () => { } }, { - async createTask({ duration }, { taskId, taskStore, taskRequestedKeepAlive }) { + async createTask({ duration }, { taskStore, taskRequestedTtl }) { + // Generate a simple task ID (in production, use a more secure method) + const taskId = `task-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`; + + // Create the task + const task = await taskStore.createTask( + { + taskId, + ttl: taskRequestedTtl, + pollInterval: 100 + }, + 0, + { method: 'tools/call', params: { name: 'createTask', arguments: { duration } } } + ); + // Simulate out-of-band work (async () => { await new Promise(resolve => setTimeout(resolve, duration)); @@ -477,11 +491,11 @@ const getServer = () => { ] }); })(); - return await taskStore.createTask({ - taskId, - keepAlive: taskRequestedKeepAlive, - pollInterval: 100 - }); + + // Return CreateTaskResult with the created task + return { + task + }; }, async getTask(_args, { taskId, taskStore }) { const task = await taskStore.getTask(taskId); diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index c0540783b..1e61d16b2 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -1,3 +1,4 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; import { InMemoryTaskStore } from './inMemoryTaskStore.js'; import { TaskMetadata, Request } from '../../types.js'; @@ -16,7 +17,7 @@ describe('InMemoryTaskStore', () => { it('should create a new task with working status', async () => { const metadata: TaskMetadata = { taskId: 'task-1', - keepAlive: 60000 + ttl: 60000 }; const request: Request = { method: 'tools/call', @@ -29,11 +30,11 @@ describe('InMemoryTaskStore', () => { expect(task).toBeDefined(); expect(task?.taskId).toBe('task-1'); expect(task?.status).toBe('working'); - expect(task?.keepAlive).toBe(60000); + expect(task?.ttl).toBe(60000); expect(task?.pollInterval).toBe(500); }); - it('should create task without keepAlive', async () => { + it('should create task without ttl', async () => { const metadata: TaskMetadata = { taskId: 'task-no-keepalive' }; @@ -46,7 +47,7 @@ describe('InMemoryTaskStore', () => { const task = await store.getTask('task-no-keepalive'); expect(task).toBeDefined(); - expect(task?.keepAlive).toBeNull(); + expect(task?.ttl).toBeNull(); }); it('should reject duplicate taskId', async () => { @@ -123,7 +124,7 @@ describe('InMemoryTaskStore', () => { const task = await store.getTask('status-test'); expect(task?.status).toBe('failed'); - expect(task?.error).toBe('Something went wrong'); + expect(task?.statusMessage).toBe('Something went wrong'); }); it('should update task status to cancelled', async () => { @@ -142,7 +143,7 @@ describe('InMemoryTaskStore', () => { beforeEach(async () => { const metadata: TaskMetadata = { taskId: 'result-test', - keepAlive: 60000 + ttl: 60000 }; await store.createTask(metadata, 333, { method: 'tools/call', @@ -205,19 +206,19 @@ describe('InMemoryTaskStore', () => { }); }); - describe('keepAlive cleanup', () => { + describe('ttl cleanup', () => { beforeEach(() => { - jest.useFakeTimers(); + vi.useFakeTimers(); }); afterEach(() => { - jest.useRealTimers(); + vi.useRealTimers(); }); - it('should cleanup task after keepAlive duration', async () => { + it('should cleanup task after ttl duration', async () => { const metadata: TaskMetadata = { taskId: 'cleanup-test', - keepAlive: 1000 + ttl: 1000 }; await store.createTask(metadata, 666, { method: 'tools/call', @@ -228,8 +229,8 @@ describe('InMemoryTaskStore', () => { let task = await store.getTask('cleanup-test'); expect(task).toBeDefined(); - // Fast-forward past keepAlive - jest.advanceTimersByTime(1001); + // Fast-forward past ttl + vi.advanceTimersByTime(1001); // Task should be cleaned up task = await store.getTask('cleanup-test'); @@ -239,7 +240,7 @@ describe('InMemoryTaskStore', () => { it('should reset cleanup timer when result is stored', async () => { const metadata: TaskMetadata = { taskId: 'reset-cleanup', - keepAlive: 1000 + ttl: 1000 }; await store.createTask(metadata, 777, { method: 'tools/call', @@ -247,7 +248,7 @@ describe('InMemoryTaskStore', () => { }); // Fast-forward 500ms - jest.advanceTimersByTime(500); + vi.advanceTimersByTime(500); // Store result (should reset timer) await store.storeTaskResult('reset-cleanup', { @@ -255,21 +256,21 @@ describe('InMemoryTaskStore', () => { }); // Fast-forward another 500ms (total 1000ms since creation, but timer was reset) - jest.advanceTimersByTime(500); + vi.advanceTimersByTime(500); // Task should still exist const task = await store.getTask('reset-cleanup'); expect(task).toBeDefined(); // Fast-forward remaining time - jest.advanceTimersByTime(501); + vi.advanceTimersByTime(501); // Now task should be cleaned up const cleanedTask = await store.getTask('reset-cleanup'); expect(cleanedTask).toBeNull(); }); - it('should not cleanup tasks without keepAlive', async () => { + it('should not cleanup tasks without ttl', async () => { const metadata: TaskMetadata = { taskId: 'no-cleanup' }; @@ -279,7 +280,7 @@ describe('InMemoryTaskStore', () => { }); // Fast-forward a long time - jest.advanceTimersByTime(100000); + vi.advanceTimersByTime(100000); // Task should still exist const task = await store.getTask('no-cleanup'); @@ -289,7 +290,7 @@ describe('InMemoryTaskStore', () => { it('should start cleanup timer when task reaches terminal state', async () => { const metadata: TaskMetadata = { taskId: 'terminal-cleanup', - keepAlive: 1000 + ttl: 1000 }; await store.createTask(metadata, 999, { method: 'tools/call', @@ -297,7 +298,7 @@ describe('InMemoryTaskStore', () => { }); // Task in non-terminal state, fast-forward - jest.advanceTimersByTime(1001); + vi.advanceTimersByTime(1001); // Task should be cleaned up let task = await store.getTask('terminal-cleanup'); @@ -306,7 +307,7 @@ describe('InMemoryTaskStore', () => { // Create another task const metadata2: TaskMetadata = { taskId: 'terminal-cleanup-2', - keepAlive: 2000 + ttl: 2000 }; await store.createTask(metadata2, 1000, { method: 'tools/call', @@ -316,8 +317,8 @@ describe('InMemoryTaskStore', () => { // Update to terminal state await store.updateTaskStatus('terminal-cleanup-2', 'completed'); - // Fast-forward past original keepAlive - jest.advanceTimersByTime(2001); + // Fast-forward past original ttl + vi.advanceTimersByTime(2001); // Task should be cleaned up task = await store.getTask('terminal-cleanup-2'); @@ -426,11 +427,11 @@ describe('InMemoryTaskStore', () => { describe('cleanup', () => { it('should clear all timers and tasks', async () => { - await store.createTask({ taskId: 'task-1', keepAlive: 1000 }, 1, { + await store.createTask({ taskId: 'task-1', ttl: 1000 }, 1, { method: 'tools/call', params: {} }); - await store.createTask({ taskId: 'task-2', keepAlive: 2000 }, 2, { + await store.createTask({ taskId: 'task-2', ttl: 2000 }, 2, { method: 'tools/call', params: {} }); @@ -461,10 +462,10 @@ describe('InMemoryTaskStore', () => { await expect(store.deleteTask('non-existent')).rejects.toThrow('Task with ID non-existent not found'); }); - it('should clear cleanup timer when deleting task with keepAlive', async () => { - jest.useFakeTimers(); + it('should clear cleanup timer when deleting task with ttl', async () => { + vi.useFakeTimers(); - await store.createTask({ taskId: 'task-with-timer', keepAlive: 1000 }, 1, { + await store.createTask({ taskId: 'task-with-timer', ttl: 1000 }, 1, { method: 'tools/call', params: {} }); @@ -473,13 +474,13 @@ describe('InMemoryTaskStore', () => { await store.deleteTask('task-with-timer'); - // Fast-forward past keepAlive time - jest.advanceTimersByTime(1001); + // Fast-forward past ttl time + vi.advanceTimersByTime(1001); // Task should not exist (it was deleted immediately, not cleaned up by timer) expect(await store.getTask('task-with-timer')).toBeNull(); - jest.useRealTimers(); + vi.useRealTimers(); }); it('should delete task with result', async () => { diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index cb6e49b9e..c702f0e9d 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -12,7 +12,7 @@ interface StoredTask { * A simple in-memory implementation of TaskStore for demonstration purposes. * * This implementation stores all tasks in memory and provides automatic cleanup - * based on the keepAlive duration specified in the task metadata. + * based on the ttl duration specified in the task metadata. * * Note: This is not suitable for production use as all data is lost on restart. * For production, consider implementing TaskStore with a database or distributed cache. @@ -31,7 +31,8 @@ export class InMemoryTaskStore implements TaskStore { const task: Task = { taskId, status: 'working', - keepAlive: metadata.keepAlive ?? null, + ttl: metadata.ttl ?? null, + createdAt: new Date().toISOString(), pollInterval: metadata.pollInterval ?? 500 }; @@ -41,12 +42,12 @@ export class InMemoryTaskStore implements TaskStore { requestId }); - // Schedule cleanup if keepAlive is specified - if (metadata.keepAlive) { + // Schedule cleanup if ttl is specified + if (metadata.ttl) { const timer = setTimeout(() => { this.tasks.delete(taskId); this.cleanupTimers.delete(taskId); - }, metadata.keepAlive); + }, metadata.ttl); this.cleanupTimers.set(taskId, timer); } @@ -68,8 +69,8 @@ export class InMemoryTaskStore implements TaskStore { stored.result = result; stored.task.status = 'completed'; - // Reset cleanup timer to start from now (if keepAlive is set) - if (stored.task.keepAlive) { + // Reset cleanup timer to start from now (if ttl is set) + if (stored.task.ttl) { const existingTimer = this.cleanupTimers.get(taskId); if (existingTimer) { clearTimeout(existingTimer); @@ -78,7 +79,7 @@ export class InMemoryTaskStore implements TaskStore { const timer = setTimeout(() => { this.tasks.delete(taskId); this.cleanupTimers.delete(taskId); - }, stored.task.keepAlive); + }, stored.task.ttl); this.cleanupTimers.set(taskId, timer); } @@ -97,19 +98,19 @@ export class InMemoryTaskStore implements TaskStore { return stored.result; } - async updateTaskStatus(taskId: string, status: Task['status'], error?: string, _sessionId?: string): Promise { + async updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string, _sessionId?: string): Promise { const stored = this.tasks.get(taskId); if (!stored) { throw new Error(`Task with ID ${taskId} not found`); } stored.task.status = status; - if (error) { - stored.task.error = error; + if (statusMessage) { + stored.task.statusMessage = statusMessage; } - // If task is in a terminal state and has keepAlive, start cleanup timer - if (isTerminal(status) && stored.task.keepAlive) { + // If task is in a terminal state and has ttl, start cleanup timer + if (isTerminal(status) && stored.task.ttl) { const existingTimer = this.cleanupTimers.get(taskId); if (existingTimer) { clearTimeout(existingTimer); @@ -118,7 +119,7 @@ export class InMemoryTaskStore implements TaskStore { const timer = setTimeout(() => { this.tasks.delete(taskId); this.cleanupTimers.delete(taskId); - }, stored.task.keepAlive); + }, stored.task.ttl); this.cleanupTimers.set(taskId, timer); } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 1b765668a..f01e35864 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -967,7 +967,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -978,20 +978,29 @@ describe('Task-based execution', () => { // Set up a tool handler that simulates some work server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } + if (request.params.name === 'test-tool') { // Simulate some async work await new Promise(resolve => setTimeout(resolve, 10)); const result = { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; } @@ -1021,7 +1030,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1034,17 +1043,20 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Use beginCallTool to create a task - const taskId = 'test-task-123'; const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { - taskId, - keepAlive: 60000 + ttl: 60000 } }); // Wait for the task to complete await pendingRequest.result(); + // Get the task ID from the task list since it's generated automatically + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; + // Verify we can retrieve the task const task = await client.getTask({ taskId }); expect(task).toBeDefined(); @@ -1104,7 +1116,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1135,7 +1147,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1155,7 +1167,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1168,11 +1180,19 @@ describe('Task-based execution', () => { // Set up client elicitation handler client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } // Capture the request to verify metadata later @@ -1188,11 +1208,19 @@ describe('Task-based execution', () => { // Set up server tool that makes a nested elicitation request server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } if (request.params.name === 'collect-info') { @@ -1225,8 +1253,8 @@ describe('Task-based execution', () => { } ] }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; } @@ -1250,24 +1278,23 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Call tool WITH task metadata - const taskId = 'test-task-456'; + // Call tool WITH task creation const pendingRequest = client.beginCallTool({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { task: { - taskId, - keepAlive: 60000 + ttl: 60000 } }); // Wait for completion await pendingRequest.result(); - // Verify the nested elicitation request received the related-task metadata + // Verify the nested elicitation request was made (related-task metadata is no longer automatically attached) expect(capturedElicitRequest).toBeDefined(); - expect(capturedElicitRequest!.params._meta).toBeDefined(); - expect(capturedElicitRequest!.params._meta?.['modelcontextprotocol.io/related-task']).toEqual({ - taskId: 'test-task-456' - }); + + // Get the task ID from the task list since it's generated automatically + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; // Verify tool result was correct const result = await client.getTaskResult({ taskId }, CallToolResultSchema); @@ -1305,7 +1332,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1314,19 +1341,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } const result = { action: 'accept', content: { username: 'server-test-user', confirmed: true } }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; }); @@ -1341,7 +1376,6 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Server creates task on client via elicitation - const taskId = 'server-elicit-create'; const ElicitResultSchema = z.object({ action: z.enum(['accept', 'decline', 'cancel']), content: z.record(z.unknown()).optional() @@ -1363,11 +1397,16 @@ describe('Task-based execution', () => { } }, ElicitResultSchema, - { task: { taskId, keepAlive: 60000 } } + { task: { ttl: 60000 } } ); await pendingRequest.result(); + // Get the task ID from the task list since it's generated automatically + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; + // Verify task was created const task = await server.getTask({ taskId }); expect(task.status).toBe('completed'); @@ -1385,7 +1424,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1394,19 +1433,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } const result = { action: 'accept', content: { username: 'list-user' } }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; }); @@ -1421,7 +1468,6 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create task - const taskId = 'server-elicit-get'; const ElicitResultSchema = z.object({ action: z.enum(['accept', 'decline', 'cancel']), content: z.record(z.unknown()).optional() @@ -1439,10 +1485,15 @@ describe('Task-based execution', () => { } }, ElicitResultSchema, - { task: { taskId, keepAlive: 60000 } } + { task: { ttl: 60000 } } ); await pending.result(); + // Get the task ID from the task list since it's generated automatically + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; + // Query task const task = await server.getTask({ taskId }); expect(task).toBeDefined(); @@ -1462,7 +1513,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1471,19 +1522,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } const result = { action: 'accept', content: { username: 'result-user', confirmed: true } }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; }); @@ -1498,7 +1557,6 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create task - const taskId = 'server-elicit-result'; const ElicitResultSchema = z.object({ action: z.enum(['accept', 'decline', 'cancel']), content: z.record(z.unknown()).optional() @@ -1519,10 +1577,15 @@ describe('Task-based execution', () => { } }, ElicitResultSchema, - { task: { taskId, keepAlive: 60000 } } + { task: { ttl: 60000 } } ); await pending.result(); + // Get the task ID from the task list since it's generated automatically + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; + // Query result const result = await server.getTaskResult({ taskId }, ElicitResultSchema); expect(result.action).toBe('accept'); @@ -1541,7 +1604,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1550,19 +1613,27 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } const result = { action: 'accept', content: { username: 'list-user' } }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; }); @@ -1577,7 +1648,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1595,8 +1666,8 @@ describe('Task-based execution', () => { content: z.record(z.unknown()).optional() }); - const taskIds = ['server-elicit-list-1', 'server-elicit-list-2']; - for (const taskId of taskIds) { + const createdTaskIds: string[] = []; + for (let i = 0; i < 2; i++) { const pending = server.beginRequest( { method: 'elicitation/create', @@ -1609,15 +1680,22 @@ describe('Task-based execution', () => { } }, ElicitResultSchema, - { task: { taskId, keepAlive: 60000 } } + { task: { ttl: 60000 } } ); await pending.result(); + + // Get the task ID from the task list + const taskList = await server.listTasks(); + const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); + if (newTask) { + createdTaskIds.push(newTask.taskId); + } } // Query task list const taskList = await server.listTasks(); expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); - for (const taskId of taskIds) { + for (const taskId of createdTaskIds) { expect(taskList.tasks).toContainEqual( expect.objectContaining({ taskId, @@ -1642,7 +1720,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1653,11 +1731,19 @@ describe('Task-based execution', () => { // Set up a tool handler with variable delay server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } if (request.params.name === 'async-tool') { const delay = (request.params.arguments?.delay as number) || 10; @@ -1665,8 +1751,8 @@ describe('Task-based execution', () => { const result = { content: [{ type: 'text', text: `Completed task ${request.params.arguments?.taskNum || 'unknown'}` }] }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; } @@ -1699,7 +1785,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1712,16 +1798,20 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create multiple tasks concurrently - const taskIds = ['concurrent-1', 'concurrent-2', 'concurrent-3', 'concurrent-4']; - const pendingRequests = taskIds.map((taskId, index) => + const pendingRequests = Array.from({ length: 4 }, (_, index) => client.beginCallTool({ name: 'async-tool', arguments: { delay: 10 + index * 5, taskNum: index + 1 } }, CallToolResultSchema, { - task: { taskId, keepAlive: 60000 } + task: { ttl: 60000 } }) ); // Wait for all tasks to complete await Promise.all(pendingRequests.map(p => p.result())); + // Get all task IDs from the task list + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(4); + const taskIds = taskList.tasks.map(t => t.taskId); + // Verify all tasks completed successfully for (let i = 0; i < taskIds.length; i++) { const task = await client.getTask({ taskId: taskIds[i] }); @@ -1733,9 +1823,9 @@ describe('Task-based execution', () => { } // Verify listTasks returns all tasks - const taskList = await client.listTasks(); + const finalTaskList = await client.listTasks(); for (const taskId of taskIds) { - expect(taskList.tasks).toContainEqual(expect.objectContaining({ taskId })); + expect(finalTaskList.tasks).toContainEqual(expect.objectContaining({ taskId })); } // Cleanup @@ -1768,7 +1858,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1787,7 +1877,7 @@ describe('Task-based execution', () => { tasks: { requests: { tools: { - call: true + call: {} } } } @@ -1815,7 +1905,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1839,7 +1929,7 @@ describe('Task-based execution', () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1872,7 +1962,7 @@ test('should respect client task capabilities', async () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1881,19 +1971,27 @@ test('should respect client task capabilities', async () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (_request, extra) => { - if (extra.taskId) { - await extra.taskStore?.createTask({ - taskId: extra.taskId, - keepAlive: extra.taskRequestedKeepAlive - }); + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + await extra.taskStore.createTask( + { + taskId, + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); } const result = { action: 'accept', content: { username: 'test-user' } }; - if (extra.taskId) { - await extra.taskStore?.storeTaskResult(extra.taskId, result); + if (taskId && extra.taskStore) { + await extra.taskStore.storeTaskResult(taskId, result); } return result; }); @@ -1908,7 +2006,7 @@ test('should respect client task capabilities', async () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1927,7 +2025,7 @@ test('should respect client task capabilities', async () => { tasks: { requests: { elicitation: { - create: true + create: {} } } } @@ -1951,11 +2049,16 @@ test('should respect client task capabilities', async () => { } }, ElicitResultSchema, - { task: { taskId: 'test-task', keepAlive: 60000 } } + { task: { ttl: 60000 } } ); await expect(pendingRequest.result()).resolves.not.toThrow(); await expect(server.listTasks()).resolves.not.toThrow(); - await expect(server.getTask({ taskId: 'test-task' })).resolves.not.toThrow(); + + // Get the task ID from the task list since it's generated automatically + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const taskId = taskList.tasks[0].taskId; + await expect(server.getTask({ taskId })).resolves.not.toThrow(); // This should throw because client doesn't support task creation for sampling/createMessage await expect( diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 24d001e9a..428c25ae8 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -768,7 +768,7 @@ describe('tool()', () => { 'test', 'A tool with everything', { name: z.string() }, - { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false, taskHint: false }, + { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false, taskHint: 'never' }, async ({ name }) => ({ content: [{ type: 'text', text: `Hello, ${name}!` }] }) @@ -783,7 +783,7 @@ describe('tool()', () => { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false, - taskHint: false + taskHint: 'never' } }, async ({ name }) => ({ @@ -808,7 +808,7 @@ describe('tool()', () => { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false, - taskHint: false + taskHint: 'never' }); expect(result.tools[1].name).toBe('test (new api)'); expect(result.tools[1].description).toBe('A tool with everything'); @@ -837,7 +837,7 @@ describe('tool()', () => { title: 'Complete Test Tool with empty params', readOnlyHint: true, openWorldHint: false, - taskHint: false + taskHint: 'never' }, async () => ({ content: [{ type: 'text', text: 'Test response' }] @@ -853,7 +853,7 @@ describe('tool()', () => { title: 'Complete Test Tool with empty params', readOnlyHint: true, openWorldHint: false, - taskHint: false + taskHint: 'never' } }, async () => ({ @@ -878,7 +878,7 @@ describe('tool()', () => { title: 'Complete Test Tool with empty params', readOnlyHint: true, openWorldHint: false, - taskHint: false + taskHint: 'never' }); expect(result.tools[1].name).toBe('test (new api)'); expect(result.tools[1].description).toBe('A tool with everything but empty params'); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 81fa16150..a1bcd3d9f 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -34,7 +34,6 @@ import { CreateTaskResult, GetTaskResult, Result, - TASK_META_KEY, CompleteRequestPrompt, CompleteRequestResourceTemplate, assertCompleteRequestPrompt, @@ -44,7 +43,7 @@ import { Completable, CompletableDef } from './completable.js'; import { UriTemplate, Variables } from '../shared/uriTemplate.js'; import { RequestHandlerExtra, RequestTaskStore } from '../shared/protocol.js'; import { Transport } from '../shared/transport.js'; -import { isTerminal } from '../shared/task.js'; + import { validateAndWarnToolName } from '../shared/toolNameValidation.js'; /** @@ -133,10 +132,10 @@ export class McpServer { }) ); - this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { + this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { const tool = this._registeredTools[request.params.name]; - let result: CallToolResult; + let result: CallToolResult | CreateTaskResult; try { if (!tool) { @@ -147,8 +146,8 @@ export class McpServer { throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); } + const isTaskRequest = !!request.params.task; if (tool.inputSchema) { - const cb = tool.callback as ToolCallback; const parseResult = await tool.inputSchema.safeParseAsync(request.params.arguments); if (!parseResult.success) { throw new McpError( @@ -159,10 +158,40 @@ export class McpServer { const args = parseResult.data; - result = await Promise.resolve(cb(args, extra)); + const handler = tool.handler as AnyToolHandler; + if ('createTask' in handler) { + const cb = handler.createTask; + if (!extra.taskStore) { + throw new Error('No task store provided.'); + } + + // Needed to show the compiler this field exists + const taskExtra = { ...extra, taskStore: extra.taskStore }; + result = await Promise.resolve(cb(args, taskExtra)); + } else { + const cb = handler; + result = await Promise.resolve(cb(args, extra)); + } } else { - const cb = tool.callback as ToolCallback; - result = await Promise.resolve(cb(extra)); + const handler = tool.handler as AnyToolHandler; + if ('createTask' in handler) { + const cb = handler.createTask; + if (!extra.taskStore) { + throw new Error('No task store provided.'); + } + + // Needed to show the compiler this field exists + const taskExtra = { ...extra, taskStore: extra.taskStore }; + result = await Promise.resolve(cb(taskExtra)); + } else { + const cb = handler; + result = await Promise.resolve(cb(extra)); + } + } + + if (isTaskRequest) { + // Return the CreateTaskResult immediately + return result; } if (tool.outputSchema && !result.isError) { @@ -668,7 +697,7 @@ export class McpServer { outputSchema: ZodRawShape | ZodType | undefined, annotations: ToolAnnotations | undefined, _meta: Record | undefined, - callback: ToolCallback + handler: AnyToolHandler ): RegisteredTool { // Validate tool name according to SEP specification validateAndWarnToolName(name); @@ -680,7 +709,7 @@ export class McpServer { outputSchema: getZodSchemaObject(outputSchema), annotations, _meta, - callback, + handler: handler, enabled: true, disable: () => registeredTool.update({ enabled: false }), enable: () => registeredTool.update({ enabled: true }), @@ -696,7 +725,7 @@ export class McpServer { if (typeof updates.title !== 'undefined') registeredTool.title = updates.title; if (typeof updates.description !== 'undefined') registeredTool.description = updates.description; if (typeof updates.paramsSchema !== 'undefined') registeredTool.inputSchema = z.object(updates.paramsSchema); - if (typeof updates.callback !== 'undefined') registeredTool.callback = updates.callback; + if (typeof updates.callback !== 'undefined') registeredTool.handler = updates.callback; if (typeof updates.annotations !== 'undefined') registeredTool.annotations = updates.annotations; if (typeof updates._meta !== 'undefined') registeredTool._meta = updates._meta; if (typeof updates.enabled !== 'undefined') registeredTool.enabled = updates.enabled; @@ -863,34 +892,16 @@ export class McpServer { }, handler: ToolTaskHandler ): RegisteredTool { - // TODO: Attach to individual request handlers and remove this wrapper - const cb: ToolCallback = (async (...args) => { - const [inputArgs, extra] = args; - - const taskStore = extra.taskStore; - if (!taskStore) { - throw new Error('Task store is not available'); - } - - const taskMetadata = extra._meta?.[TASK_META_KEY]; - const taskId = taskMetadata?.taskId; - if (!taskId) { - throw new Error('No task ID provided'); - } - - // Internal polling to allow using this interface before internals are hooked up - const taskExtra = { ...extra, taskId, taskStore }; - let task = await handler.createTask(inputArgs, taskExtra); - do { - await new Promise(resolve => setTimeout(resolve, task.pollInterval ?? 5000)); - task = await handler.getTask(inputArgs, taskExtra); - } while (!isTerminal(task.status)); - - const result: CallToolResult = await handler.getTaskResult(inputArgs, taskExtra); - return result; - }) as ToolCallback; - - return this.registerTool(name, { ...config, annotations: { ...config.annotations, taskHint: true } }, cb); + return this._createRegisteredTool( + name, + config.title, + config.description, + config.inputSchema, + config.outputSchema, + { ...config.annotations, taskHint: 'always' }, + config._meta, + handler + ); } /** @@ -1082,6 +1093,16 @@ export class ResourceTemplate { } } +export type BaseToolCallback< + SendResultT extends Result, + Extra extends RequestHandlerExtra, + Args extends undefined | ZodRawShape | ZodType +> = Args extends ZodRawShape + ? (args: z.objectOutputType, extra: Extra) => SendResultT | Promise + : Args extends ZodType + ? (args: T, extra: Extra) => SendResultT | Promise + : (extra: Extra) => SendResultT | Promise; + /** * Callback for a tool handler registered with Server.tool(). * @@ -1092,30 +1113,51 @@ export class ResourceTemplate { * - `content` if the tool does not have an outputSchema * - Both fields are optional but typically one should be provided */ -export type ToolCallback = undefined> = Args extends ZodRawShape - ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra - ) => CallToolResult | Promise - : Args extends ZodType - ? (args: T, extra: RequestHandlerExtra) => CallToolResult | Promise - : (extra: RequestHandlerExtra) => CallToolResult | Promise; +export type ToolCallback = undefined> = BaseToolCallback< + CallToolResult, + RequestHandlerExtra, + Args +>; -export interface TaskRequestHandlerExtra extends RequestHandlerExtra { +export interface CreateTaskRequestHandlerExtra + extends RequestHandlerExtra { + taskStore: RequestTaskStore; +} + +export interface TaskRequestHandlerExtra + extends RequestHandlerExtra { taskId: string; taskStore: RequestTaskStore; } -export type TaskRequestHandler = Args extends ZodRawShape - ? (args: z.objectOutputType, extra: TaskRequestHandlerExtra) => SendResultT | Promise - : (extra: TaskRequestHandlerExtra) => SendResultT | Promise; +export type CreateTaskRequestHandler< + SendResultT extends Result, + Args extends undefined | ZodRawShape | ZodType = undefined +> = BaseToolCallback, Args>; + +export type TaskRequestHandler< + SendResultT extends Result, + Args extends undefined | ZodRawShape | ZodType = undefined +> = BaseToolCallback, Args>; -export interface ToolTaskHandler { - createTask: TaskRequestHandler; +export interface ToolTaskHandler = undefined> { + createTask: CreateTaskRequestHandler; getTask: TaskRequestHandler; getTaskResult: TaskRequestHandler; } +/** + * Supertype for tool handler callbacks registered with Server.registerTool() and Server.registerToolTask(). + */ +export type AnyToolCallback = undefined> = + | ToolCallback + | TaskRequestHandler; + +/** + * Supertype that can handle both regular tools (simple callback) and task-based tools (task handler object). + */ +export type AnyToolHandler = undefined> = ToolCallback | ToolTaskHandler; + export type RegisteredTool = { title?: string; description?: string; @@ -1123,7 +1165,7 @@ export type RegisteredTool = { outputSchema?: ZodType; annotations?: ToolAnnotations; _meta?: Record; - callback: ToolCallback; + handler: AnyToolHandler; enabled: boolean; enable(): void; disable(): void; diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index d20f91c31..eb0a6f367 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -11,13 +11,12 @@ import { Result, ServerCapabilities, Task, - TASK_META_KEY, TaskMetadata } from '../types.js'; import { Protocol, mergeCapabilities } from './protocol.js'; import { Transport } from './transport.js'; import { TaskStore } from './task.js'; -import { MockInstance } from 'vitest'; +import { MockInstance, vi } from 'vitest'; // Mock Transport class class MockTransport implements Transport { @@ -35,32 +34,33 @@ class MockTransport implements Transport { function createMockTaskStore(options?: { onStatus?: (status: Task['status']) => void; onList?: () => void; -}): TaskStore & { [K in keyof TaskStore]: jest.Mock, Parameters> } { +}): TaskStore & { [K in keyof TaskStore]: MockInstance } { const tasks: Record = {}; return { - createTask: jest.fn((taskMetadata: TaskMetadata, _1: RequestId, _2: Request) => { + createTask: vi.fn((taskMetadata: TaskMetadata, _1: RequestId, _2: Request) => { const task = (tasks[taskMetadata.taskId] = { taskId: taskMetadata.taskId, status: (taskMetadata.status as Task['status'] | undefined) ?? 'working', - keepAlive: taskMetadata.keepAlive ?? null, + ttl: taskMetadata.ttl ?? null, + createdAt: new Date().toISOString(), pollInterval: (taskMetadata.pollInterval as Task['pollInterval'] | undefined) ?? 1000 }); options?.onStatus?.('working'); return Promise.resolve(task); }), - getTask: jest.fn((taskId: string) => { + getTask: vi.fn((taskId: string) => { return Promise.resolve(tasks[taskId] ?? null); }), - updateTaskStatus: jest.fn((taskId, status, error) => { + updateTaskStatus: vi.fn((taskId, status, statusMessage) => { const task = tasks[taskId]; if (task) { task.status = status; - task.error = error; + task.statusMessage = statusMessage; options?.onStatus?.(task.status); } return Promise.resolve(); }), - storeTaskResult: jest.fn((taskId: string, result: Result) => { + storeTaskResult: vi.fn((taskId: string, result: Result) => { const task = tasks[taskId]; if (task) { task.status = 'completed'; @@ -69,21 +69,21 @@ function createMockTaskStore(options?: { } return Promise.resolve(); }), - getTaskResult: jest.fn((taskId: string) => { + getTaskResult: vi.fn((taskId: string) => { const task = tasks[taskId]; if (task?.result) { return Promise.resolve(task.result); } throw new Error('Task result not found'); }), - listTasks: jest.fn(() => { + listTasks: vi.fn(() => { const result = { tasks: Object.values(tasks) }; options?.onList?.(); return Promise.resolve(result); }), - deleteTask: jest.fn((taskId: string) => { + deleteTask: vi.fn((taskId: string) => { if (tasks[taskId]) { delete tasks[taskId]; return Promise.resolve(); @@ -857,11 +857,11 @@ describe('mergeCapabilities', () => { describe('Task-based execution', () => { let protocol: Protocol; let transport: MockTransport; - let sendSpy: jest.SpyInstance; + let sendSpy: MockInstance; beforeEach(() => { transport = new MockTransport(); - sendSpy = jest.spyOn(transport, 'send'); + sendSpy = vi.spyOn(transport, 'send'); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} @@ -872,7 +872,7 @@ describe('Task-based execution', () => { }); describe('beginRequest with task metadata', () => { - it('should inject task metadata into _meta field', async () => { + it('should include task parameters at top level', async () => { await protocol.connect(transport); const request = { @@ -886,8 +886,8 @@ describe('Task-based execution', () => { protocol.beginRequest(request, resultSchema, { task: { - taskId: 'my-task-123', - keepAlive: 30000 + ttl: 30000, + pollInterval: 1000 } }); @@ -896,11 +896,9 @@ describe('Task-based execution', () => { method: 'tools/call', params: { name: 'test-tool', - _meta: { - [TASK_META_KEY]: { - taskId: 'my-task-123', - keepAlive: 30000 - } + task: { + ttl: 30000, + pollInterval: 1000 } } }), @@ -908,7 +906,7 @@ describe('Task-based execution', () => { ); }); - it('should preserve existing _meta when adding task metadata', async () => { + it('should preserve existing _meta and add task parameters at top level', async () => { await protocol.connect(transport); const request = { @@ -927,7 +925,7 @@ describe('Task-based execution', () => { protocol.beginRequest(request, resultSchema, { task: { - taskId: 'my-task-456' + ttl: 60000 } }); @@ -936,10 +934,10 @@ describe('Task-based execution', () => { params: { name: 'test-tool', _meta: { - customField: 'customValue', - [TASK_META_KEY]: { - taskId: 'my-task-456' - } + customField: 'customValue' + }, + task: { + ttl: 60000 } } }), @@ -961,12 +959,12 @@ describe('Task-based execution', () => { const pendingRequest = protocol.beginRequest(request, resultSchema, { task: { - taskId: 'my-task-789' + ttl: 30000 } }); expect(pendingRequest).toBeDefined(); - expect(pendingRequest.taskId).toBe('my-task-789'); + expect(pendingRequest.taskId).toBeUndefined(); // taskId is generated by receiver, not provided by client }); }); @@ -1050,22 +1048,24 @@ describe('Task-based execution', () => { protocol.beginRequest(request, resultSchema, { task: { - taskId: 'my-task-combined' + ttl: 60000, + pollInterval: 1000 }, relatedTask: { taskId: 'parent-task' }, - onprogress: jest.fn() + onprogress: vi.fn() }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ params: { name: 'test-tool', + task: { + ttl: 60000, + pollInterval: 1000 + }, _meta: { - [TASK_META_KEY]: { - taskId: 'my-task-combined' - }, [RELATED_TASK_META_KEY]: { taskId: 'parent-task' }, @@ -1079,171 +1079,16 @@ describe('Task-based execution', () => { }); describe('task status transitions', () => { - it('should transition from submitted to working when handler starts', async () => { - const workingProcessed = createLatch(); - const mockTaskStore = createMockTaskStore({ - onStatus: status => { - if (status === 'working') { - workingProcessed.releaseLatch(); - } - } - }); - - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); - - await protocol.connect(transport); - - protocol.setRequestHandler(CallToolRequestSchema, async request => { - await mockTaskStore.createTask( - { - taskId: 'test-task', - keepAlive: 60000 - }, - 1, - request, - undefined - ); - await mockTaskStore.updateTaskStatus('test-task', 'working', undefined, undefined); - return { - result: 'success' - }; - }); - - transport.onmessage?.({ - jsonrpc: '2.0', - id: 1, - method: 'tools/call', - params: { - name: 'test', - arguments: {}, - _meta: { - [TASK_META_KEY]: { - taskId: 'test-task', - keepAlive: 60000 - } - } - } - }); - - await workingProcessed.waitForLatch(); - - expect(mockTaskStore.createTask).toHaveBeenCalledWith( - { taskId: 'test-task', keepAlive: 60000 }, - 1, - { - method: 'tools/call', - params: expect.any(Object) - }, - undefined - ); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined, undefined); + it('should be handled by tool implementors, not protocol layer', () => { + // Task status management is now the responsibility of tool implementors + expect(true).toBe(true); }); - it('should transition to input_required during extra.sendRequest', async () => { + it('should handle requests with task creation parameters in top-level task field', async () => { + // This test documents that task creation parameters are now in the top-level task field + // rather than in _meta, and that task management is handled by tool implementors const mockTaskStore = createMockTaskStore(); - const responsiveTransport = new MockTransport(); - responsiveTransport.send = jest.fn().mockImplementation(async (message: unknown) => { - if ( - typeof message === 'object' && - message !== null && - 'method' in message && - 'id' in message && - message.method === 'nested/request' && - responsiveTransport.onmessage - ) { - setTimeout(() => { - responsiveTransport.onmessage?.({ - jsonrpc: '2.0', - id: (message as { id: number }).id, - result: { nested: 'response' } - }); - }, 5); - } - }); - - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); - - await protocol.connect(responsiveTransport); - - const capturedUpdateCalls: Array<{ taskId: string; status: string }> = []; - mockTaskStore.updateTaskStatus.mockImplementation((taskId, status) => { - capturedUpdateCalls.push({ taskId, status }); - return Promise.resolve(); - }); - - protocol.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - await mockTaskStore.createTask( - { - taskId: 'test-task', - keepAlive: 60000 - }, - 1, - request - ); - await mockTaskStore.updateTaskStatus('test-task', 'working', undefined, undefined); - await extra.sendRequest({ method: 'nested/request', params: {} }, z.object({ nested: z.string() })); - return { result: 'success' }; - }); - - responsiveTransport.onmessage?.({ - jsonrpc: '2.0', - id: 1, - method: 'tools/call', - params: { - name: 'test', - arguments: {}, - _meta: { - [TASK_META_KEY]: { - taskId: 'test-task', - keepAlive: 60000 - } - } - } - }); - - await new Promise(resolve => setTimeout(resolve, 100)); - - expect(capturedUpdateCalls).toContainEqual({ taskId: 'test-task', status: 'working' }); - expect(capturedUpdateCalls).toContainEqual({ taskId: 'test-task', status: 'input_required' }); - - const inputRequiredIndex = capturedUpdateCalls.findIndex(c => c.status === 'input_required'); - const workingCalls = capturedUpdateCalls.filter(c => c.status === 'working'); - expect(workingCalls).toHaveLength(2); - - let workingCount = 0; - const secondWorkingIndex = capturedUpdateCalls.findIndex(c => { - if (c.status === 'working') { - workingCount++; - return workingCount === 2; - } - return false; - }); - expect(secondWorkingIndex).toBeGreaterThan(inputRequiredIndex); - }); - - it('should mark task as completed when storeTaskResult is called', async () => { - const completeProcessed = createLatch(); - const mockTaskStore = createMockTaskStore({ - onStatus: status => { - if (status === 'completed') { - completeProcessed.releaseLatch(); - } - } - }); - protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} @@ -1255,19 +1100,12 @@ describe('Task-based execution', () => { await protocol.connect(transport); protocol.setRequestHandler(CallToolRequestSchema, async request => { - await mockTaskStore.createTask( - { - taskId: 'test-task', - keepAlive: 60000 - }, - 1, - request - ); - await mockTaskStore.updateTaskStatus('test-task', 'working', undefined, undefined); - await mockTaskStore.storeTaskResult('test-task', { result: 'success' }, undefined); - return { - result: 'success' - }; + // Tool implementor can access task creation parameters from request.params.task + expect(request.params.task).toEqual({ + ttl: 60000, + pollInterval: 1000 + }); + return { result: 'success' }; }); transport.onmessage?.({ @@ -1277,18 +1115,15 @@ describe('Task-based execution', () => { params: { name: 'test', arguments: {}, - _meta: { - [TASK_META_KEY]: { - taskId: 'test-task', - keepAlive: 60000 - } + task: { + ttl: 60000, + pollInterval: 1000 } } }); - await completeProcessed.waitForLatch(); - - expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith('test-task', { result: 'success' }, undefined); + // Wait for the request to be processed + await new Promise(resolve => setTimeout(resolve, 10)); }); }); @@ -1349,8 +1184,8 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(3); expect(sentMessage.result.tasks).toEqual([ - { taskId: 'task-1', status: 'completed', keepAlive: null, pollInterval: 500 }, - { taskId: 'task-2', status: 'working', keepAlive: 60000, pollInterval: 1000 } + { taskId: 'task-1', status: 'completed', ttl: null, createdAt: expect.any(String), pollInterval: 500 }, + { taskId: 'task-2', status: 'working', ttl: null, createdAt: expect.any(String), pollInterval: 1000 } ]); expect(sentMessage.result._meta).toEqual({}); }); @@ -1398,7 +1233,9 @@ describe('Task-based execution', () => { const sentMessage = sendSpy.mock.calls[0][0]; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(2); - expect(sentMessage.result.tasks).toEqual([{ taskId: 'task-3', status: 'working', keepAlive: null, pollInterval: 500 }]); + expect(sentMessage.result.tasks).toEqual([ + { taskId: 'task-3', status: 'working', ttl: null, createdAt: expect.any(String), pollInterval: 500 } + ]); expect(sentMessage.result.nextCursor).toBeUndefined(); expect(sentMessage.result._meta).toEqual({}); }); @@ -1485,13 +1322,9 @@ describe('Task-based execution', () => { jsonrpc: '2.0', id: sendSpy.mock.calls[0][0].id, result: { - tasks: [{ taskId: 'task-1', status: 'completed', keepAlive: null, pollInterval: 500 }], + tasks: [{ taskId: 'task-1', status: 'completed', ttl: null, createdAt: '2024-01-01T00:00:00Z', pollInterval: 500 }], nextCursor: undefined, - _meta: { - [TASK_META_KEY]: expect.objectContaining({ - taskId: expect.any(String) - }) - } + _meta: {} } }); }, 10); @@ -1520,13 +1353,11 @@ describe('Task-based execution', () => { jsonrpc: '2.0', id: sendSpy.mock.calls[0][0].id, result: { - tasks: [{ taskId: 'task-11', status: 'working', keepAlive: 30000, pollInterval: 1000 }], + tasks: [ + { taskId: 'task-11', status: 'working', ttl: 30000, createdAt: '2024-01-01T00:00:00Z', pollInterval: 1000 } + ], nextCursor: 'task-11', - _meta: { - [TASK_META_KEY]: expect.objectContaining({ - taskId: expect.any(String) - }) - } + _meta: {} } }); }, 10); @@ -1548,8 +1379,8 @@ describe('Task-based execution', () => { }); }); - describe('deleteTask', () => { - it('should handle tasks/delete requests and delete task from TaskStore', async () => { + describe('cancelTask', () => { + it('should handle tasks/cancel requests and update task status to cancelled', async () => { const taskDeleted = createLatch(); const mockTaskStore = createMockTaskStore(); await mockTaskStore.createTask( @@ -1563,8 +1394,8 @@ describe('Task-based execution', () => { } ); - mockTaskStore.deleteTask.mockImplementation(async (taskId: string) => { - if (taskId === 'task-to-delete') { + mockTaskStore.updateTaskStatus.mockImplementation(async (taskId: string, status: string) => { + if (taskId === 'task-to-delete' && status === 'cancelled') { taskDeleted.releaseLatch(); return; } @@ -1579,14 +1410,14 @@ describe('Task-based execution', () => { protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); - const sendSpy = jest.spyOn(serverTransport, 'send'); + const sendSpy = vi.spyOn(serverTransport, 'send'); await serverProtocol.connect(serverTransport); serverTransport.onmessage?.({ jsonrpc: '2.0', id: 5, - method: 'tasks/delete', + method: 'tasks/cancel', params: { taskId: 'task-to-delete' } @@ -1594,7 +1425,7 @@ describe('Task-based execution', () => { await taskDeleted.waitForLatch(); - expect(mockTaskStore.deleteTask).toHaveBeenCalledWith('task-to-delete', undefined); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('task-to-delete', 'cancelled', undefined, undefined); // eslint-disable-next-line @typescript-eslint/no-explicit-any const sentMessage = sendSpy.mock.calls[0][0] as any; expect(sentMessage.jsonrpc).toBe('2.0'); @@ -1606,7 +1437,7 @@ describe('Task-based execution', () => { const taskDeleted = createLatch(); const mockTaskStore = createMockTaskStore(); - mockTaskStore.deleteTask.mockImplementation(async () => { + mockTaskStore.updateTaskStatus.mockImplementation(async () => { taskDeleted.releaseLatch(); throw new Error('Task with ID non-existent not found'); }); @@ -1619,14 +1450,14 @@ describe('Task-based execution', () => { protected assertTaskHandlerCapability(): void {} })({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); - const sendSpy = jest.spyOn(serverTransport, 'send'); + const sendSpy = vi.spyOn(serverTransport, 'send'); await serverProtocol.connect(serverTransport); serverTransport.onmessage?.({ jsonrpc: '2.0', id: 6, - method: 'tasks/delete', + method: 'tasks/cancel', params: { taskId: 'non-existent' } @@ -1634,20 +1465,20 @@ describe('Task-based execution', () => { await taskDeleted.waitForLatch(); - expect(mockTaskStore.deleteTask).toHaveBeenCalledWith('non-existent', undefined); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('non-existent', 'cancelled', undefined, undefined); // eslint-disable-next-line @typescript-eslint/no-explicit-any const sentMessage = sendSpy.mock.calls[0][0] as any; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(6); expect(sentMessage.error).toBeDefined(); expect(sentMessage.error.code).toBe(-32600); // InvalidRequest error code - expect(sentMessage.error.message).toContain('Failed to delete task'); + expect(sentMessage.error.message).toContain('Failed to cancel task'); }); - it('should call deleteTask method from client side', async () => { + it('should call cancelTask method from client side', async () => { await protocol.connect(transport); - const deleteTaskPromise = protocol.deleteTask({ taskId: 'task-to-delete' }); + const deleteTaskPromise = protocol.cancelTask({ taskId: 'task-to-delete' }); // Simulate server response setTimeout(() => { @@ -1664,7 +1495,7 @@ describe('Task-based execution', () => { expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ - method: 'tasks/delete', + method: 'tasks/cancel', params: { taskId: 'task-to-delete' } diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 91a16c5b9..5de8e6f6d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -10,8 +10,8 @@ import { GetTaskPayloadRequestSchema, ListTasksRequestSchema, ListTasksResultSchema, - DeleteTaskRequestSchema, - DeleteTaskResultSchema, + CancelTaskRequestSchema, + CancelTaskResultSchema, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, @@ -34,9 +34,8 @@ import { RequestMeta, MessageExtraInfo, RequestInfo, - TaskCreatedNotificationSchema, - TASK_META_KEY, GetTaskResult, + TaskCreationParams, TaskMetadata, RelatedTaskMetadata, CancelledNotification, @@ -72,8 +71,8 @@ export type ProtocolOptions = { */ debouncedNotificationMethods?: string[]; /** - * Optional task storage implementation. If provided, the implementation will automatically - * handle task creation, status tracking, and result storage. + * Optional task storage implementation. If provided, enables task-related request handlers + * and provides task storage capabilities to request handlers. */ taskStore?: TaskStore; /** @@ -124,9 +123,9 @@ export type RequestOptions = { maxTotalTimeout?: number; /** - * If provided, augments the request with task metadata to enable call-now, fetch-later execution patterns. + * If provided, augments the request with task creation parameters to enable call-now, fetch-later execution patterns. */ - task?: TaskMetadata; + task?: TaskCreationParams; /** * If provided, associates this request with a related task. @@ -163,15 +162,17 @@ export interface RequestTaskStore { * Creates a new task with the given metadata and original request. * * @param task - The task creation metadata from the request - * @returns The task state including status, keepAlive, pollInterval, and optional error + * @param requestId - The JSON-RPC request ID + * @param request - The original request that triggered task creation + * @returns The task state including status, ttl, pollInterval, and optional statusMessage */ - createTask(task: TaskMetadata): Promise; + createTask(task: TaskMetadata, requestId: RequestId, request: Request): Promise; /** * Gets the current status of a task. * * @param taskId - The task identifier - * @returns The task state including status, keepAlive, pollInterval, and optional error + * @returns The task state including status, ttl, pollInterval, and optional statusMessage */ getTask(taskId: string): Promise; @@ -196,13 +197,9 @@ export interface RequestTaskStore { * * @param taskId - The task identifier * @param status - The new status - * @param error - Optional error message if status is 'failed' or 'cancelled' + * @param statusMessage - Optional diagnostic message for failed tasks or other status information */ - updateTaskStatus( - taskId: string, - status: Status, - error?: ErrorReason - ): Promise; + updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; /** * Lists tasks, optionally starting from a pagination cursor. @@ -259,7 +256,7 @@ export type RequestHandlerExtra< taskStore?: RequestTaskStore; - taskRequestedKeepAlive?: number; + taskRequestedTtl?: number | null; /** * The original HTTP request. @@ -314,8 +311,7 @@ export abstract class Protocol = new Map(); private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); - private _pendingTaskCreations: Map void; reject: (reason: unknown) => void }> = new Map(); - private _requestIdToTaskId: Map = new Map(); + private _taskStore?: TaskStore; /** @@ -354,17 +350,6 @@ export abstract class Protocol { - const taskId = notification.params?._meta?.[RELATED_TASK_META_KEY]?.taskId; - if (taskId) { - const resolver = this._pendingTaskCreations.get(taskId); - if (resolver) { - resolver.resolve(); - this._pendingTaskCreations.delete(taskId); - } - } - }); - this.setRequestHandler( PingRequestSchema, // Automatic pong by default. @@ -430,16 +415,16 @@ export abstract class Protocol { + this.setRequestHandler(CancelTaskRequestSchema, async (request, extra) => { try { - await this._taskStore!.deleteTask(request.params.taskId, extra.sessionId); + await this._taskStore!.updateTaskStatus(request.params.taskId, 'cancelled', undefined, extra.sessionId); return { _meta: {} } as SendResultT; } catch (error) { throw new McpError( ErrorCode.InvalidRequest, - `Failed to delete task: ${error instanceof Error ? error.message : String(error)}` + `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` ); } }); @@ -452,19 +437,6 @@ export abstract class Protocol { - // If this request had a task, mark it as cancelled in storage - const taskId = this._requestIdToTaskId.get(requestId); - const taskStore = this._taskStore ? this.requestTaskStore(undefined, sessionId) : undefined; - if (taskId && this._taskStore) { - try { - await taskStore?.updateTaskStatus(taskId, 'cancelled', undefined); - } catch (error) { - this._onerror(new Error(`Failed to cancel task ${taskId}: ${error}`)); - } - } - } - private _setupTimeout( messageId: number, timeout: number, @@ -508,16 +480,6 @@ export abstract class Protocol = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, _meta: request.params?._meta, sendNotification: async notification => { - const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; - await this.notification(notification, { relatedRequestId: request.id, relatedTask }); + await this.notification(notification, { relatedRequestId: request.id }); }, sendRequest: async (r, resultSchema, options?) => { - const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; - if (taskMetadata) { - // Allow this to throw to the caller (request handler) - await taskStore?.updateTaskStatus(taskMetadata.taskId, 'input_required', undefined); - } - try { - return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id, relatedTask }); - } finally { - if (taskMetadata) { - // Allow this to throw to the caller (request handler) - await taskStore?.updateTaskStatus(taskMetadata.taskId, 'working', undefined); - } - } + return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id }); }, authInfo: extra?.authInfo, requestId: request.id, requestInfo: extra?.requestInfo, - taskId: taskMetadata?.taskId, + taskId: undefined, taskStore: taskStore, - taskRequestedKeepAlive: taskMetadata?.keepAlive + taskRequestedTtl: taskCreationParams?.ttl }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() .then(() => { // If this request asked for task creation, check capability first - if (taskMetadata) { + if (taskCreationParams) { // Check if the request method supports task creation this.assertTaskHandlerCapability(request.method); } @@ -665,7 +608,6 @@ export abstract class Protocol { if (abortController.signal.aborted) { // Request was cancelled - await this._postcancel(request.id, capturedTransport?.sessionId); return; } @@ -679,7 +621,6 @@ export abstract class Protocol { if (abortController.signal.aborted) { // Request was cancelled - await this._postcancel(request.id, capturedTransport?.sessionId); return; } @@ -716,8 +657,7 @@ export abstract class Protocol { const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; - const { taskId, keepAlive } = task ?? {}; - - // For tasks, create an advance promise for the creation notification to avoid - // race conditions with installing this callback. - const taskCreated = taskId ? this._waitForTaskCreation(taskId) : Promise.resolve(); // Send the request const result = new Promise>((resolve, reject) => { const earlyReject = (error: unknown) => { - // Clean up task tracking if we reject before sending - if (taskId) { - const resolver = this._pendingTaskCreations.get(taskId); - resolver?.reject(error); - this._pendingTaskCreations.delete(taskId); - } reject(error); }; @@ -844,8 +762,8 @@ export abstract class Protocol { - return new Promise((resolve, reject) => { - this._pendingTaskCreations.set(taskId, { resolve, reject }); - }); - } - /** * Gets the current status of a task. */ @@ -1007,11 +906,11 @@ export abstract class Protocol> { - // @ts-expect-error SendRequestT cannot directly contain DeleteTaskRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/delete', params }, DeleteTaskResultSchema, options); + async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { + // @ts-expect-error SendRequestT cannot directly contain CancelTaskRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); } /** @@ -1215,7 +1114,7 @@ export abstract class Protocol { return taskStore.getTaskResult(taskId, sessionId); }, - updateTaskStatus: async (taskId, status, errorReason) => { + updateTaskStatus: async (taskId, status, statusMessage) => { try { // Check the current task status to avoid overwriting terminal states // as a safeguard for when the TaskStore implementation doesn't try @@ -1232,7 +1131,7 @@ export abstract class Protocol { constructor( readonly protocol: Protocol, - readonly taskCreatedHandle: Promise, readonly resultHandle: Promise, readonly resultSchema: ZodType, readonly taskId?: string, @@ -33,24 +32,17 @@ export class PendingRequest { - // Start task handler immediately without waiting for creation notification - const taskPromise = this.taskHandler(this.taskId!, { + // Call onTaskCreated immediately since task is created synchronously by tool implementor + await onTaskCreated(); + + // Start task polling + return await this.taskHandler(this.taskId!, { onTaskCreated, onTaskStatus }); - - // Call onTaskCreated callback when notification arrives, but don't block taskHandler - // The promise is tied to the lifecycle of taskPromise, so it won't leak - this.taskCreatedHandle - .then(() => onTaskCreated()) - .catch(() => { - // Silently ignore if notification never arrives or fails - }); - - return await taskPromise; })(), this.resultHandle ]).then(([task, result]) => { @@ -61,7 +53,7 @@ export class PendingRequest; @@ -23,7 +23,7 @@ export interface TaskStore { * * @param taskId - The task identifier * @param sessionId - Optional session ID for binding the query to a specific session - * @returns The task state including status, keepAlive, pollInterval, and optional error + * @returns The task state including status, ttl, pollInterval, and optional statusMessage */ getTask(taskId: string, sessionId?: string): Promise; @@ -50,10 +50,10 @@ export interface TaskStore { * * @param taskId - The task identifier * @param status - The new status - * @param error - Optional error message if status is 'failed' or 'cancelled' + * @param statusMessage - Optional diagnostic message for failed tasks or other status information * @param sessionId - Optional session ID for binding the operation to a specific session */ - updateTaskStatus(taskId: string, status: Task['status'], error?: string, sessionId?: string): Promise; + updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string, sessionId?: string): Promise; /** * Lists tasks, optionally starting from a pagination cursor. diff --git a/src/types.ts b/src/types.ts index 75465eef4..87ec79b13 100644 --- a/src/types.ts +++ b/src/types.ts @@ -5,8 +5,7 @@ export const LATEST_PROTOCOL_VERSION = '2025-06-18'; export const DEFAULT_NEGOTIATED_PROTOCOL_VERSION = '2025-03-26'; export const SUPPORTED_PROTOCOL_VERSIONS = [LATEST_PROTOCOL_VERSION, '2025-03-26', '2024-11-05', '2024-10-07']; -export const TASK_META_KEY = 'modelcontextprotocol.io/task'; -export const RELATED_TASK_META_KEY = 'modelcontextprotocol.io/related-task'; +export const RELATED_TASK_META_KEY = 'io.modelcontextprotocol/related-task'; /* JSON-RPC types */ export const JSONRPC_VERSION = '2.0'; @@ -32,22 +31,19 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]); export const CursorSchema = z.string(); /** - * Task creation metadata, used to ask that the server create a task to represent a request. + * Task creation parameters, used to ask that the server create a task to represent a request. + * The taskId is generated by the receiver, not provided by the requestor. */ -export const TaskMetadataSchema = z +export const TaskCreationParamsSchema = z .object({ /** - * The task ID to use as a reference to the created task. + * Time in milliseconds to keep task results available after completion. + * If null, the task has unlimited lifetime until manually cleaned up. */ - taskId: z.string(), + ttl: z.union([z.number(), z.null()]).optional(), /** - * Time in milliseconds to ask to keep task results available after completion. Only used with taskId. - */ - keepAlive: z.number().optional(), - - /** - * Time in milliseconds to wait between task status requests. Only used with taskId. + * Time in milliseconds to wait between task status requests. */ pollInterval: z.optional(z.number()) }) @@ -56,6 +52,13 @@ export const TaskMetadataSchema = z */ .passthrough(); +/** + * @deprecated Use TaskCreationParamsSchema instead. This is kept for backward compatibility. + */ +export const TaskMetadataSchema = TaskCreationParamsSchema.extend({ + taskId: z.string() +}); + /** * Task association metadata, used to signal which task a message originated from. */ @@ -71,10 +74,6 @@ const RequestMetaSchema = z * If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. */ progressToken: z.optional(ProgressTokenSchema), - /** - * If specified, the caller is requesting that the receiver create a task to represent the request. - */ - [TASK_META_KEY]: z.optional(TaskMetadataSchema), /** * If specified, this request is related to the provided task. */ @@ -89,6 +88,11 @@ const RequestMetaSchema = z * Common params for any request. */ const BaseRequestParamsSchema = z.object({ + /** + * If specified, the caller is requesting that the receiver create a task to represent the request. + * Task creation parameters are now at the top level instead of in _meta. + */ + task: TaskCreationParamsSchema.optional(), /** * See [General fields: `_meta`](/specification/draft/basic/index#meta) for notes on `_meta` usage. */ @@ -337,6 +341,14 @@ export const ImplementationSchema = BaseMetadataSchema.extend({ */ export const ClientTasksCapabilitySchema = z .object({ + /** + * Present if the client supports listing tasks. + */ + list: z.optional(z.object({}).passthrough()), + /** + * Present if the client supports cancelling tasks. + */ + cancel: z.optional(z.object({}).passthrough()), /** * Capabilities for task creation on specific request types. */ @@ -349,7 +361,7 @@ export const ClientTasksCapabilitySchema = z sampling: z.optional( z .object({ - createMessage: z.optional(z.boolean()) + createMessage: z.optional(z.object({}).passthrough()) }) .passthrough() ), @@ -359,7 +371,7 @@ export const ClientTasksCapabilitySchema = z elicitation: z.optional( z .object({ - create: z.optional(z.boolean()) + create: z.optional(z.object({}).passthrough()) }) .passthrough() ) @@ -374,6 +386,14 @@ export const ClientTasksCapabilitySchema = z */ export const ServerTasksCapabilitySchema = z .object({ + /** + * Present if the server supports listing tasks. + */ + list: z.optional(z.object({}).passthrough()), + /** + * Present if the server supports cancelling tasks. + */ + cancel: z.optional(z.object({}).passthrough()), /** * Capabilities for task creation on specific request types. */ @@ -386,7 +406,7 @@ export const ServerTasksCapabilitySchema = z tools: z.optional( z .object({ - call: z.optional(z.boolean()) + call: z.optional(z.object({}).passthrough()) }) .passthrough() ) @@ -615,12 +635,28 @@ export const PaginatedResultSchema = ResultSchema.extend({ export const TaskSchema = z.object({ taskId: z.string(), status: z.enum(['working', 'input_required', 'completed', 'failed', 'cancelled']), - keepAlive: z.union([z.number(), z.null()]), + /** + * Time in milliseconds to keep task results available after completion. + * If null, the task has unlimited lifetime until manually cleaned up. + */ + ttl: z.union([z.number(), z.null()]), + /** + * ISO 8601 timestamp when the task was created. + */ + createdAt: z.string(), pollInterval: z.optional(z.number()), - error: z.optional(z.string()) + /** + * Optional diagnostic message for failed tasks or other status information. + */ + statusMessage: z.optional(z.string()) }); -export const CreateTaskResultSchema = ResultSchema.merge(TaskSchema); +/** + * Result returned when a task is created, containing the task data wrapped in a task field. + */ +export const CreateTaskResultSchema = ResultSchema.extend({ + task: TaskSchema +}); /** * An out-of-band notification used to inform the receiver of a task being created. @@ -629,6 +665,21 @@ export const TaskCreatedNotificationSchema = NotificationSchema.extend({ method: z.literal('notifications/tasks/created') }); +/** + * Parameters for task status notification. + */ +export const TaskStatusNotificationParamsSchema = z.object({ + task: TaskSchema +}); + +/** + * A notification sent when a task's status changes. + */ +export const TaskStatusNotificationSchema = NotificationSchema.extend({ + method: z.literal('notifications/tasks/status'), + params: TaskStatusNotificationParamsSchema +}); + /** * A request to get the state of a specific task. */ @@ -669,19 +720,19 @@ export const ListTasksResultSchema = PaginatedResultSchema.extend({ }); /** - * A request to delete a specific task. + * A request to cancel a specific task. */ -export const DeleteTaskRequestSchema = RequestSchema.extend({ - method: z.literal('tasks/delete'), +export const CancelTaskRequestSchema = RequestSchema.extend({ + method: z.literal('tasks/cancel'), params: BaseRequestParamsSchema.extend({ taskId: z.string() }) }); /** - * The response to a tasks/delete request. + * The response to a tasks/cancel request. */ -export const DeleteTaskResultSchema = ResultSchema; +export const CancelTaskResultSchema = ResultSchema; /* Resources */ /** @@ -1138,13 +1189,14 @@ export const ToolAnnotationsSchema = z.object({ openWorldHint: z.boolean().optional(), /** - * If true, this tool is expected to support task-augmented execution. - * This allows clients to handle long-running operations through polling - * the task system. + * Indicates the tool's preference for task-augmented execution. + * - "always": Clients SHOULD invoke the tool as a task + * - "optional": Clients MAY invoke the tool as a task or normal request + * - "never": Clients SHALL NOT attempt to invoke the tool as a task * - * Default: false + * If not present, defaults to "never". */ - taskHint: z.boolean().optional() + taskHint: z.enum(['always', 'optional', 'never']).optional() }); /** @@ -1915,17 +1967,20 @@ export type ProgressNotification = Infer; /* Tasks */ export type Task = Infer; +export type TaskCreationParams = Infer; export type TaskMetadata = Infer; export type RelatedTaskMetadata = Infer; export type CreateTaskResult = Infer; export type TaskCreatedNotification = Infer; +export type TaskStatusNotificationParams = Infer; +export type TaskStatusNotification = Infer; export type GetTaskRequest = Infer; export type GetTaskResult = Infer; export type GetTaskPayloadRequest = Infer; export type ListTasksRequest = Infer; export type ListTasksResult = Infer; -export type DeleteTaskRequest = Infer; -export type DeleteTaskResult = Infer; +export type CancelTaskRequest = Infer; +export type CancelTaskResult = Infer; /* Pagination */ export type PaginatedRequestParams = Infer; From 0af275f290d03f9a2ebc51265670020858add915 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 18 Nov 2025 14:20:54 -0800 Subject: [PATCH 33/84] Implement tasks/cancel --- src/client/index.ts | 2 +- src/examples/shared/inMemoryTaskStore.test.ts | 54 ------------ src/examples/shared/inMemoryTaskStore.ts | 17 ---- src/server/index.ts | 2 +- src/shared/protocol.test.ts | 87 +++++++++++++++---- src/shared/protocol.ts | 34 +++++--- src/shared/task.ts | 9 -- 7 files changed, 93 insertions(+), 112 deletions(-) diff --git a/src/client/index.ts b/src/client/index.ts index e5dfe61b8..6107baf38 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -424,7 +424,7 @@ export class Client< case 'tasks/get': case 'tasks/list': case 'tasks/result': - case 'tasks/delete': + case 'tasks/cancel': if (!this._capabilities.tasks) { throw new Error(`Client does not support tasks capability (required for ${method})`); } diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index 1e61d16b2..7b682033d 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -443,58 +443,4 @@ describe('InMemoryTaskStore', () => { expect(store.getAllTasks()).toHaveLength(0); }); }); - - describe('deleteTask', () => { - it('should delete an existing task', async () => { - await store.createTask({ taskId: 'task-to-delete' }, 1, { - method: 'tools/call', - params: {} - }); - - expect(await store.getTask('task-to-delete')).toBeDefined(); - - await store.deleteTask('task-to-delete'); - - expect(await store.getTask('task-to-delete')).toBeNull(); - }); - - it('should throw error when deleting non-existent task', async () => { - await expect(store.deleteTask('non-existent')).rejects.toThrow('Task with ID non-existent not found'); - }); - - it('should clear cleanup timer when deleting task with ttl', async () => { - vi.useFakeTimers(); - - await store.createTask({ taskId: 'task-with-timer', ttl: 1000 }, 1, { - method: 'tools/call', - params: {} - }); - - expect(await store.getTask('task-with-timer')).toBeDefined(); - - await store.deleteTask('task-with-timer'); - - // Fast-forward past ttl time - vi.advanceTimersByTime(1001); - - // Task should not exist (it was deleted immediately, not cleaned up by timer) - expect(await store.getTask('task-with-timer')).toBeNull(); - - vi.useRealTimers(); - }); - - it('should delete task with result', async () => { - await store.createTask({ taskId: 'task-with-result' }, 1, { - method: 'tools/call', - params: {} - }); - - const result = { content: [{ type: 'text' as const, text: 'Result' }] }; - await store.storeTaskResult('task-with-result', result); - - await store.deleteTask('task-with-result'); - - expect(await store.getTask('task-with-result')).toBeNull(); - }); - }); }); diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index c702f0e9d..fa97d0957 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -151,23 +151,6 @@ export class InMemoryTaskStore implements TaskStore { return { tasks, nextCursor }; } - async deleteTask(taskId: string, _sessionId?: string): Promise { - const stored = this.tasks.get(taskId); - if (!stored) { - throw new Error(`Task with ID ${taskId} not found`); - } - - // Clear any associated cleanup timer - const existingTimer = this.cleanupTimers.get(taskId); - if (existingTimer) { - clearTimeout(existingTimer); - this.cleanupTimers.delete(taskId); - } - - // Delete the task - this.tasks.delete(taskId); - } - /** * Cleanup all timers (useful for testing or graceful shutdown) */ diff --git a/src/server/index.ts b/src/server/index.ts index 6a0d5c8a9..74d277401 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -280,7 +280,7 @@ export class Server< case 'tasks/get': case 'tasks/list': case 'tasks/result': - case 'tasks/delete': + case 'tasks/cancel': if (!this._capabilities.tasks) { throw new Error(`Server does not support tasks capability (required for ${method})`); } diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index eb0a6f367..f777e3fed 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -82,13 +82,6 @@ function createMockTaskStore(options?: { }; options?.onList?.(); return Promise.resolve(result); - }), - deleteTask: vi.fn((taskId: string) => { - if (tasks[taskId]) { - delete tasks[taskId]; - return Promise.resolve(); - } - return Promise.reject(new Error(`Task with ID ${taskId} not found`)); }) }; } @@ -1383,7 +1376,7 @@ describe('Task-based execution', () => { it('should handle tasks/cancel requests and update task status to cancelled', async () => { const taskDeleted = createLatch(); const mockTaskStore = createMockTaskStore(); - await mockTaskStore.createTask( + const task = await mockTaskStore.createTask( { taskId: 'task-to-delete' }, @@ -1394,6 +1387,7 @@ describe('Task-based execution', () => { } ); + mockTaskStore.getTask.mockResolvedValue(task); mockTaskStore.updateTaskStatus.mockImplementation(async (taskId: string, status: string) => { if (taskId === 'task-to-delete' && status === 'cancelled') { taskDeleted.releaseLatch(); @@ -1425,7 +1419,13 @@ describe('Task-based execution', () => { await taskDeleted.waitForLatch(); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('task-to-delete', 'cancelled', undefined, undefined); + expect(mockTaskStore.getTask).toHaveBeenCalledWith('task-to-delete', undefined); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith( + 'task-to-delete', + 'cancelled', + 'Client cancelled task execution.', + undefined + ); // eslint-disable-next-line @typescript-eslint/no-explicit-any const sentMessage = sendSpy.mock.calls[0][0] as any; expect(sentMessage.jsonrpc).toBe('2.0'); @@ -1433,14 +1433,11 @@ describe('Task-based execution', () => { expect(sentMessage.result._meta).toBeDefined(); }); - it('should return error with code -32600 when task does not exist', async () => { + it('should return error with code -32602 when task does not exist', async () => { const taskDeleted = createLatch(); const mockTaskStore = createMockTaskStore(); - mockTaskStore.updateTaskStatus.mockImplementation(async () => { - taskDeleted.releaseLatch(); - throw new Error('Task with ID non-existent not found'); - }); + mockTaskStore.getTask.mockResolvedValue(null); const serverProtocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1463,16 +1460,70 @@ describe('Task-based execution', () => { } }); - await taskDeleted.waitForLatch(); + // Wait a bit for the async handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + taskDeleted.releaseLatch(); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('non-existent', 'cancelled', undefined, undefined); + expect(mockTaskStore.getTask).toHaveBeenCalledWith('non-existent', undefined); // eslint-disable-next-line @typescript-eslint/no-explicit-any const sentMessage = sendSpy.mock.calls[0][0] as any; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(6); expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32600); // InvalidRequest error code - expect(sentMessage.error.message).toContain('Failed to cancel task'); + expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.message).toContain('Task not found'); + }); + + it('should return error with code -32602 when trying to cancel a task in terminal status', async () => { + const mockTaskStore = createMockTaskStore(); + const completedTask = await mockTaskStore.createTask( + { + taskId: 'completed-task' + }, + 1, + { + method: 'test/method', + params: {} + } + ); + // Set task to completed status + completedTask.status = 'completed'; + + mockTaskStore.getTask.mockResolvedValue(completedTask); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 7, + method: 'tasks/cancel', + params: { + taskId: 'completed-task' + } + }); + + // Wait a bit for the async handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.getTask).toHaveBeenCalledWith('completed-task', undefined); + expect(mockTaskStore.updateTaskStatus).not.toHaveBeenCalled(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const sentMessage = sendSpy.mock.calls[0][0] as any; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(7); + expect(sentMessage.error).toBeDefined(); + expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.message).toContain('Cannot cancel task in terminal status'); }); it('should call cancelTask method from client side', async () => { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 5de8e6f6d..4d6fb0215 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -208,14 +208,6 @@ export interface RequestTaskStore { * @returns An object containing the tasks array and an optional nextCursor */ listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; - - /** - * Deletes a specific task and its associated data. - * - * @param taskId - The task identifier - * @throws Error if the task doesn't exist or cannot be deleted - */ - deleteTask(taskId: string): Promise; } /** @@ -417,11 +409,32 @@ export abstract class Protocol { try { - await this._taskStore!.updateTaskStatus(request.params.taskId, 'cancelled', undefined, extra.sessionId); + // Get the current task to check if it's in a terminal state, in case the implementation is not atomic + const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); + + if (!task) { + throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); + } + + // Reject cancellation of terminal tasks + if (isTerminal(task.status)) { + throw new McpError(ErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); + } + + await this._taskStore!.updateTaskStatus( + request.params.taskId, + 'cancelled', + 'Client cancelled task execution.', + extra.sessionId + ); return { _meta: {} } as SendResultT; } catch (error) { + // Re-throw McpError as-is + if (error instanceof McpError) { + throw error; + } throw new McpError( ErrorCode.InvalidRequest, `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` @@ -1138,9 +1151,6 @@ export abstract class Protocol { return taskStore.listTasks(cursor, sessionId); - }, - deleteTask: taskId => { - return taskStore.deleteTask(taskId, sessionId); } }; } diff --git a/src/shared/task.ts b/src/shared/task.ts index b9e328ee7..e3d30e025 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -63,15 +63,6 @@ export interface TaskStore { * @returns An object containing the tasks array and an optional nextCursor */ listTasks(cursor?: string, sessionId?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; - - /** - * Deletes a specific task and its associated data. - * - * @param taskId - The task identifier - * @param sessionId - Optional session ID for binding the operation to a specific session - * @throws Error if the task doesn't exist or cannot be deleted - */ - deleteTask(taskId: string, sessionId?: string): Promise; } /** From 25d0b14d5ae3def4d38c816ab3a145637775cba8 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 18 Nov 2025 15:04:18 -0800 Subject: [PATCH 34/84] Update TaskStore interface --- README.md | 10 +- src/client/index.test.ts | 57 +++--- src/examples/shared/inMemoryTaskStore.test.ts | 183 +++++++++--------- src/examples/shared/inMemoryTaskStore.ts | 28 ++- src/server/index.test.ts | 105 +++++----- src/shared/protocol.test.ts | 81 ++++---- src/shared/protocol.ts | 22 +-- src/shared/task.ts | 11 +- src/types.ts | 8 - 9 files changed, 240 insertions(+), 265 deletions(-) diff --git a/README.md b/README.md index 47588d600..fae06a772 100644 --- a/README.md +++ b/README.md @@ -1324,15 +1324,17 @@ import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprot // Implement TaskStore backed by your database (e.g., PostgreSQL, Redis, etc.) class MyTaskStore implements TaskStore { - async createTask(metadata, requestId, request) { + async createTask(taskParams, requestId, request) { + // Generate unique taskId and createdAt timestamp // Store task in your database + // Return Task object with generated taskId } async getTask(taskId) { // Retrieve task from your database } - async updateTaskStatus(taskId, status, errorMessage?) { + async updateTaskStatus(taskId, status, statusMessage?) { // Update task status in your database } @@ -1343,6 +1345,10 @@ class MyTaskStore implements TaskStore { async getTaskResult(taskId) { // Retrieve task result from your database } + + async listTasks(cursor?, sessionId?) { + // List tasks with pagination support + } } const taskStore = new MyTaskStore(); diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 71d5a36d5..071ff7274 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -1367,15 +1367,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } if (request.params.name === 'test-tool') { @@ -1454,15 +1453,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } if (request.params.name === 'test-tool') { @@ -1541,15 +1539,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } if (request.params.name === 'test-tool') { @@ -1626,15 +1623,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } if (request.params.name === 'test-tool') { const result = { @@ -1738,15 +1734,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } const result = { action: 'accept', @@ -1842,15 +1837,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } const result = { action: 'accept', @@ -1944,15 +1938,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } const result = { action: 'accept', @@ -2045,15 +2038,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } const result = { action: 'accept', @@ -2161,15 +2153,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( - { - taskId, - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } if (request.params.name === 'test-tool') { const result = { diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index 7b682033d..c994d54d9 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; import { InMemoryTaskStore } from './inMemoryTaskStore.js'; -import { TaskMetadata, Request } from '../../types.js'; +import { TaskCreationParams, Request } from '../../types.js'; describe('InMemoryTaskStore', () => { let store: InMemoryTaskStore; @@ -15,8 +15,7 @@ describe('InMemoryTaskStore', () => { describe('createTask', () => { it('should create a new task with working status', async () => { - const metadata: TaskMetadata = { - taskId: 'task-1', + const taskParams: TaskCreationParams = { ttl: 60000 }; const request: Request = { @@ -24,44 +23,43 @@ describe('InMemoryTaskStore', () => { params: { name: 'test-tool' } }; - await store.createTask(metadata, 123, request); + const task = await store.createTask(taskParams, 123, request); - const task = await store.getTask('task-1'); expect(task).toBeDefined(); - expect(task?.taskId).toBe('task-1'); - expect(task?.status).toBe('working'); - expect(task?.ttl).toBe(60000); - expect(task?.pollInterval).toBe(500); + expect(task.taskId).toBeDefined(); + expect(typeof task.taskId).toBe('string'); + expect(task.taskId.length).toBeGreaterThan(0); + expect(task.status).toBe('working'); + expect(task.ttl).toBe(60000); + expect(task.pollInterval).toBe(500); + expect(task.createdAt).toBeDefined(); + expect(new Date(task.createdAt).getTime()).toBeGreaterThan(0); }); it('should create task without ttl', async () => { - const metadata: TaskMetadata = { - taskId: 'task-no-keepalive' - }; + const taskParams: TaskCreationParams = {}; const request: Request = { method: 'tools/call', params: {} }; - await store.createTask(metadata, 456, request); + const task = await store.createTask(taskParams, 456, request); - const task = await store.getTask('task-no-keepalive'); expect(task).toBeDefined(); - expect(task?.ttl).toBeNull(); + expect(task.ttl).toBeNull(); }); - it('should reject duplicate taskId', async () => { - const metadata: TaskMetadata = { - taskId: 'duplicate-task' - }; + it('should generate unique taskIds', async () => { + const taskParams: TaskCreationParams = {}; const request: Request = { method: 'tools/call', params: {} }; - await store.createTask(metadata, 789, request); + const task1 = await store.createTask(taskParams, 789, request); + const task2 = await store.createTask(taskParams, 790, request); - await expect(store.createTask(metadata, 790, request)).rejects.toThrow('Task with ID duplicate-task already exists'); + expect(task1.taskId).not.toBe(task2.taskId); }); }); @@ -72,65 +70,64 @@ describe('InMemoryTaskStore', () => { }); it('should return task state', async () => { - const metadata: TaskMetadata = { - taskId: 'get-test' - }; + const taskParams: TaskCreationParams = {}; const request: Request = { method: 'tools/call', params: {} }; - await store.createTask(metadata, 111, request); - await store.updateTaskStatus('get-test', 'working'); + const createdTask = await store.createTask(taskParams, 111, request); + await store.updateTaskStatus(createdTask.taskId, 'working'); - const task = await store.getTask('get-test'); + const task = await store.getTask(createdTask.taskId); expect(task).toBeDefined(); expect(task?.status).toBe('working'); }); }); describe('updateTaskStatus', () => { + let taskId: string; + beforeEach(async () => { - const metadata: TaskMetadata = { - taskId: 'status-test' - }; - await store.createTask(metadata, 222, { + const taskParams: TaskCreationParams = {}; + const createdTask = await store.createTask(taskParams, 222, { method: 'tools/call', params: {} }); + taskId = createdTask.taskId; }); it('should keep task status as working', async () => { - const task = await store.getTask('status-test'); + const task = await store.getTask(taskId); expect(task?.status).toBe('working'); }); it('should update task status to input_required', async () => { - await store.updateTaskStatus('status-test', 'input_required'); + await store.updateTaskStatus(taskId, 'input_required'); - const task = await store.getTask('status-test'); + const task = await store.getTask(taskId); expect(task?.status).toBe('input_required'); }); it('should update task status to completed', async () => { - await store.updateTaskStatus('status-test', 'completed'); + await store.updateTaskStatus(taskId, 'completed'); - const task = await store.getTask('status-test'); + const task = await store.getTask(taskId); expect(task?.status).toBe('completed'); }); it('should update task status to failed with error', async () => { - await store.updateTaskStatus('status-test', 'failed', 'Something went wrong'); + await store.updateTaskStatus(taskId, 'failed', 'Something went wrong'); - const task = await store.getTask('status-test'); + const task = await store.getTask(taskId); expect(task?.status).toBe('failed'); expect(task?.statusMessage).toBe('Something went wrong'); }); it('should update task status to cancelled', async () => { - await store.updateTaskStatus('status-test', 'cancelled'); + await store.updateTaskStatus(taskId, 'cancelled'); - const task = await store.getTask('status-test'); + const task = await store.getTask(taskId); expect(task?.status).toBe('cancelled'); }); @@ -140,15 +137,17 @@ describe('InMemoryTaskStore', () => { }); describe('storeTaskResult', () => { + let taskId: string; + beforeEach(async () => { - const metadata: TaskMetadata = { - taskId: 'result-test', + const taskParams: TaskCreationParams = { ttl: 60000 }; - await store.createTask(metadata, 333, { + const createdTask = await store.createTask(taskParams, 333, { method: 'tools/call', params: {} }); + taskId = createdTask.taskId; }); it('should store task result and set status to completed', async () => { @@ -156,12 +155,12 @@ describe('InMemoryTaskStore', () => { content: [{ type: 'text' as const, text: 'Success!' }] }; - await store.storeTaskResult('result-test', result); + await store.storeTaskResult(taskId, result); - const task = await store.getTask('result-test'); + const task = await store.getTask(taskId); expect(task?.status).toBe('completed'); - const storedResult = await store.getTaskResult('result-test'); + const storedResult = await store.getTaskResult(taskId); expect(storedResult).toEqual(result); }); @@ -176,22 +175,18 @@ describe('InMemoryTaskStore', () => { }); it('should throw if task has no result stored', async () => { - const metadata: TaskMetadata = { - taskId: 'no-result' - }; - await store.createTask(metadata, 444, { + const taskParams: TaskCreationParams = {}; + const createdTask = await store.createTask(taskParams, 444, { method: 'tools/call', params: {} }); - await expect(store.getTaskResult('no-result')).rejects.toThrow('Task no-result has no result stored'); + await expect(store.getTaskResult(createdTask.taskId)).rejects.toThrow(`Task ${createdTask.taskId} has no result stored`); }); it('should return stored result', async () => { - const metadata: TaskMetadata = { - taskId: 'with-result' - }; - await store.createTask(metadata, 555, { + const taskParams: TaskCreationParams = {}; + const createdTask = await store.createTask(taskParams, 555, { method: 'tools/call', params: {} }); @@ -199,9 +194,9 @@ describe('InMemoryTaskStore', () => { const result = { content: [{ type: 'text' as const, text: 'Result data' }] }; - await store.storeTaskResult('with-result', result); + await store.storeTaskResult(createdTask.taskId, result); - const retrieved = await store.getTaskResult('with-result'); + const retrieved = await store.getTaskResult(createdTask.taskId); expect(retrieved).toEqual(result); }); }); @@ -216,33 +211,31 @@ describe('InMemoryTaskStore', () => { }); it('should cleanup task after ttl duration', async () => { - const metadata: TaskMetadata = { - taskId: 'cleanup-test', + const taskParams: TaskCreationParams = { ttl: 1000 }; - await store.createTask(metadata, 666, { + const createdTask = await store.createTask(taskParams, 666, { method: 'tools/call', params: {} }); // Task should exist initially - let task = await store.getTask('cleanup-test'); + let task = await store.getTask(createdTask.taskId); expect(task).toBeDefined(); // Fast-forward past ttl vi.advanceTimersByTime(1001); // Task should be cleaned up - task = await store.getTask('cleanup-test'); + task = await store.getTask(createdTask.taskId); expect(task).toBeNull(); }); it('should reset cleanup timer when result is stored', async () => { - const metadata: TaskMetadata = { - taskId: 'reset-cleanup', + const taskParams: TaskCreationParams = { ttl: 1000 }; - await store.createTask(metadata, 777, { + const createdTask = await store.createTask(taskParams, 777, { method: 'tools/call', params: {} }); @@ -251,7 +244,7 @@ describe('InMemoryTaskStore', () => { vi.advanceTimersByTime(500); // Store result (should reset timer) - await store.storeTaskResult('reset-cleanup', { + await store.storeTaskResult(createdTask.taskId, { content: [{ type: 'text' as const, text: 'Done' }] }); @@ -259,22 +252,20 @@ describe('InMemoryTaskStore', () => { vi.advanceTimersByTime(500); // Task should still exist - const task = await store.getTask('reset-cleanup'); + const task = await store.getTask(createdTask.taskId); expect(task).toBeDefined(); // Fast-forward remaining time vi.advanceTimersByTime(501); // Now task should be cleaned up - const cleanedTask = await store.getTask('reset-cleanup'); + const cleanedTask = await store.getTask(createdTask.taskId); expect(cleanedTask).toBeNull(); }); it('should not cleanup tasks without ttl', async () => { - const metadata: TaskMetadata = { - taskId: 'no-cleanup' - }; - await store.createTask(metadata, 888, { + const taskParams: TaskCreationParams = {}; + const createdTask = await store.createTask(taskParams, 888, { method: 'tools/call', params: {} }); @@ -283,16 +274,15 @@ describe('InMemoryTaskStore', () => { vi.advanceTimersByTime(100000); // Task should still exist - const task = await store.getTask('no-cleanup'); + const task = await store.getTask(createdTask.taskId); expect(task).toBeDefined(); }); it('should start cleanup timer when task reaches terminal state', async () => { - const metadata: TaskMetadata = { - taskId: 'terminal-cleanup', + const taskParams: TaskCreationParams = { ttl: 1000 }; - await store.createTask(metadata, 999, { + const createdTask = await store.createTask(taskParams, 999, { method: 'tools/call', params: {} }); @@ -301,49 +291,50 @@ describe('InMemoryTaskStore', () => { vi.advanceTimersByTime(1001); // Task should be cleaned up - let task = await store.getTask('terminal-cleanup'); + let task = await store.getTask(createdTask.taskId); expect(task).toBeNull(); // Create another task - const metadata2: TaskMetadata = { - taskId: 'terminal-cleanup-2', + const taskParams2: TaskCreationParams = { ttl: 2000 }; - await store.createTask(metadata2, 1000, { + const createdTask2 = await store.createTask(taskParams2, 1000, { method: 'tools/call', params: {} }); // Update to terminal state - await store.updateTaskStatus('terminal-cleanup-2', 'completed'); + await store.updateTaskStatus(createdTask2.taskId, 'completed'); // Fast-forward past original ttl vi.advanceTimersByTime(2001); // Task should be cleaned up - task = await store.getTask('terminal-cleanup-2'); + task = await store.getTask(createdTask2.taskId); expect(task).toBeNull(); }); }); describe('getAllTasks', () => { it('should return all tasks', async () => { - await store.createTask({ taskId: 'task-1' }, 1, { + await store.createTask({}, 1, { method: 'tools/call', params: {} }); - await store.createTask({ taskId: 'task-2' }, 2, { + await store.createTask({}, 2, { method: 'tools/call', params: {} }); - await store.createTask({ taskId: 'task-3' }, 3, { + await store.createTask({}, 3, { method: 'tools/call', params: {} }); const tasks = store.getAllTasks(); expect(tasks).toHaveLength(3); - expect(tasks.map(t => t.taskId).sort()).toEqual(['task-1', 'task-2', 'task-3']); + // Verify all tasks have unique IDs + const taskIds = tasks.map(t => t.taskId); + expect(new Set(taskIds).size).toBe(3); }); it('should return empty array when no tasks', () => { @@ -360,15 +351,15 @@ describe('InMemoryTaskStore', () => { }); it('should return all tasks when less than page size', async () => { - await store.createTask({ taskId: 'task-1' }, 1, { + await store.createTask({}, 1, { method: 'tools/call', params: {} }); - await store.createTask({ taskId: 'task-2' }, 2, { + await store.createTask({}, 2, { method: 'tools/call', params: {} }); - await store.createTask({ taskId: 'task-3' }, 3, { + await store.createTask({}, 3, { method: 'tools/call', params: {} }); @@ -381,7 +372,7 @@ describe('InMemoryTaskStore', () => { it('should paginate when more than page size', async () => { // Create 15 tasks (page size is 10) for (let i = 1; i <= 15; i++) { - await store.createTask({ taskId: `task-${i}` }, i, { + await store.createTask({}, i, { method: 'tools/call', params: {} }); @@ -399,7 +390,7 @@ describe('InMemoryTaskStore', () => { }); it('should throw error for invalid cursor', async () => { - await store.createTask({ taskId: 'task-1' }, 1, { + await store.createTask({}, 1, { method: 'tools/call', params: {} }); @@ -408,9 +399,9 @@ describe('InMemoryTaskStore', () => { }); it('should continue from cursor correctly', async () => { - // Create tasks with predictable IDs + // Create 5 tasks for (let i = 1; i <= 5; i++) { - await store.createTask({ taskId: `task-${i}` }, i, { + await store.createTask({}, i, { method: 'tools/call', params: {} }); @@ -420,18 +411,18 @@ describe('InMemoryTaskStore', () => { const allTaskIds = Array.from(store.getAllTasks().map(t => t.taskId)); const result = await store.listTasks(allTaskIds[2]); - // Should get tasks after task-3 + // Should get tasks after the third task expect(result.tasks).toHaveLength(2); }); }); describe('cleanup', () => { it('should clear all timers and tasks', async () => { - await store.createTask({ taskId: 'task-1', ttl: 1000 }, 1, { + await store.createTask({ ttl: 1000 }, 1, { method: 'tools/call', params: {} }); - await store.createTask({ taskId: 'task-2', ttl: 2000 }, 2, { + await store.createTask({ ttl: 2000 }, 2, { method: 'tools/call', params: {} }); diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index fa97d0957..b179e286a 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -1,5 +1,6 @@ -import { Task, TaskMetadata, Request, RequestId, Result } from '../../types.js'; +import { Task, TaskCreationParams, Request, RequestId, Result } from '../../types.js'; import { TaskStore, isTerminal } from '../../shared/task.js'; +import { randomBytes } from 'crypto'; interface StoredTask { task: Task; @@ -12,7 +13,7 @@ interface StoredTask { * A simple in-memory implementation of TaskStore for demonstration purposes. * * This implementation stores all tasks in memory and provides automatic cleanup - * based on the ttl duration specified in the task metadata. + * based on the ttl duration specified in the task creation parameters. * * Note: This is not suitable for production use as all data is lost on restart. * For production, consider implementing TaskStore with a database or distributed cache. @@ -21,19 +22,30 @@ export class InMemoryTaskStore implements TaskStore { private tasks = new Map(); private cleanupTimers = new Map>(); - async createTask(metadata: TaskMetadata, requestId: RequestId, request: Request, _sessionId?: string): Promise { - const taskId = metadata.taskId; + /** + * Generates a unique task ID. + * Uses 16 bytes of random data encoded as hex (32 characters). + */ + private generateTaskId(): string { + return randomBytes(16).toString('hex'); + } + + async createTask(taskParams: TaskCreationParams, requestId: RequestId, request: Request, _sessionId?: string): Promise { + // Generate a unique task ID + const taskId = this.generateTaskId(); + // Ensure uniqueness if (this.tasks.has(taskId)) { throw new Error(`Task with ID ${taskId} already exists`); } + // Create task with generated ID and timestamp const task: Task = { taskId, status: 'working', - ttl: metadata.ttl ?? null, + ttl: taskParams.ttl ?? null, createdAt: new Date().toISOString(), - pollInterval: metadata.pollInterval ?? 500 + pollInterval: taskParams.pollInterval ?? 500 }; this.tasks.set(taskId, { @@ -43,11 +55,11 @@ export class InMemoryTaskStore implements TaskStore { }); // Schedule cleanup if ttl is specified - if (metadata.ttl) { + if (taskParams.ttl) { const timer = setTimeout(() => { this.tasks.delete(taskId); this.cleanupTimers.delete(taskId); - }, metadata.ttl); + }, taskParams.ttl); this.cleanupTimers.set(taskId, timer); } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index f01e35864..816230530 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -982,15 +982,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( - { - taskId, - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } if (request.params.name === 'test-tool') { @@ -1184,15 +1183,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( - { - taskId, - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } // Capture the request to verify metadata later @@ -1212,15 +1210,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( - { - taskId, - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } if (request.params.name === 'collect-info') { @@ -1346,15 +1343,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } const result = { action: 'accept', @@ -1438,15 +1434,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } const result = { action: 'accept', @@ -1527,15 +1522,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } const result = { action: 'accept', @@ -1618,15 +1612,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( + const createdTask = await extra.taskStore.createTask( { - taskId, ttl: extra.taskRequestedTtl }, extra.requestId, request ); + taskId = createdTask.taskId; } const result = { action: 'accept', @@ -1735,15 +1728,14 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( - { - taskId, - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } if (request.params.name === 'async-tool') { const delay = (request.params.arguments?.delay as number) || 10; @@ -1976,15 +1968,14 @@ test('should respect client task capabilities', async () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - taskId = `task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - await extra.taskStore.createTask( - { - taskId, - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } const result = { action: 'accept', diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index f777e3fed..37fe5fb29 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -11,7 +11,7 @@ import { Result, ServerCapabilities, Task, - TaskMetadata + TaskCreationParams } from '../types.js'; import { Protocol, mergeCapabilities } from './protocol.js'; import { Transport } from './transport.js'; @@ -37,13 +37,15 @@ function createMockTaskStore(options?: { }): TaskStore & { [K in keyof TaskStore]: MockInstance } { const tasks: Record = {}; return { - createTask: vi.fn((taskMetadata: TaskMetadata, _1: RequestId, _2: Request) => { - const task = (tasks[taskMetadata.taskId] = { - taskId: taskMetadata.taskId, - status: (taskMetadata.status as Task['status'] | undefined) ?? 'working', - ttl: taskMetadata.ttl ?? null, + createTask: vi.fn((taskParams: TaskCreationParams, _1: RequestId, _2: Request) => { + // Generate a unique task ID + const taskId = `test-task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + const task = (tasks[taskId] = { + taskId, + status: 'working', + ttl: taskParams.ttl ?? null, createdAt: new Date().toISOString(), - pollInterval: (taskMetadata.pollInterval as Task['pollInterval'] | undefined) ?? 1000 + pollInterval: taskParams.pollInterval ?? 1000 }); options?.onStatus?.('working'); return Promise.resolve(task); @@ -1126,10 +1128,8 @@ describe('Task-based execution', () => { const mockTaskStore = createMockTaskStore({ onList: () => listedTasks.releaseLatch() }); - await mockTaskStore.createTask( + const task1 = await mockTaskStore.createTask( { - taskId: 'task-1', - status: 'completed', pollInterval: 500 }, 1, @@ -1138,11 +1138,12 @@ describe('Task-based execution', () => { params: {} } ); - await mockTaskStore.createTask( + // Manually set status to completed for this test + await mockTaskStore.updateTaskStatus(task1.taskId, 'completed'); + + const task2 = await mockTaskStore.createTask( { - taskId: 'task-2', - status: 'working', - keepAlive: 60000, + ttl: 60000, pollInterval: 1000 }, 2, @@ -1177,8 +1178,8 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(3); expect(sentMessage.result.tasks).toEqual([ - { taskId: 'task-1', status: 'completed', ttl: null, createdAt: expect.any(String), pollInterval: 500 }, - { taskId: 'task-2', status: 'working', ttl: null, createdAt: expect.any(String), pollInterval: 1000 } + { taskId: task1.taskId, status: 'completed', ttl: null, createdAt: expect.any(String), pollInterval: 500 }, + { taskId: task2.taskId, status: 'working', ttl: 60000, createdAt: expect.any(String), pollInterval: 1000 } ]); expect(sentMessage.result._meta).toEqual({}); }); @@ -1188,9 +1189,8 @@ describe('Task-based execution', () => { const mockTaskStore = createMockTaskStore({ onList: () => listedTasks.releaseLatch() }); - await mockTaskStore.createTask( + const task3 = await mockTaskStore.createTask( { - taskId: 'task-3', pollInterval: 500 }, 1, @@ -1227,7 +1227,7 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(2); expect(sentMessage.result.tasks).toEqual([ - { taskId: 'task-3', status: 'working', ttl: null, createdAt: expect.any(String), pollInterval: 500 } + { taskId: task3.taskId, status: 'working', ttl: null, createdAt: expect.any(String), pollInterval: 500 } ]); expect(sentMessage.result.nextCursor).toBeUndefined(); expect(sentMessage.result._meta).toEqual({}); @@ -1376,20 +1376,14 @@ describe('Task-based execution', () => { it('should handle tasks/cancel requests and update task status to cancelled', async () => { const taskDeleted = createLatch(); const mockTaskStore = createMockTaskStore(); - const task = await mockTaskStore.createTask( - { - taskId: 'task-to-delete' - }, - 1, - { - method: 'test/method', - params: {} - } - ); + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); mockTaskStore.getTask.mockResolvedValue(task); mockTaskStore.updateTaskStatus.mockImplementation(async (taskId: string, status: string) => { - if (taskId === 'task-to-delete' && status === 'cancelled') { + if (taskId === task.taskId && status === 'cancelled') { taskDeleted.releaseLatch(); return; } @@ -1413,15 +1407,15 @@ describe('Task-based execution', () => { id: 5, method: 'tasks/cancel', params: { - taskId: 'task-to-delete' + taskId: task.taskId } }); await taskDeleted.waitForLatch(); - expect(mockTaskStore.getTask).toHaveBeenCalledWith('task-to-delete', undefined); + expect(mockTaskStore.getTask).toHaveBeenCalledWith(task.taskId, undefined); expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith( - 'task-to-delete', + task.taskId, 'cancelled', 'Client cancelled task execution.', undefined @@ -1476,19 +1470,16 @@ describe('Task-based execution', () => { it('should return error with code -32602 when trying to cancel a task in terminal status', async () => { const mockTaskStore = createMockTaskStore(); - const completedTask = await mockTaskStore.createTask( - { - taskId: 'completed-task' - }, - 1, - { - method: 'test/method', - params: {} - } - ); + const completedTask = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); // Set task to completed status + await mockTaskStore.updateTaskStatus(completedTask.taskId, 'completed'); completedTask.status = 'completed'; + // Reset the mock so we can check it's not called during cancellation + mockTaskStore.updateTaskStatus.mockClear(); mockTaskStore.getTask.mockResolvedValue(completedTask); const serverProtocol = new (class extends Protocol { @@ -1508,14 +1499,14 @@ describe('Task-based execution', () => { id: 7, method: 'tasks/cancel', params: { - taskId: 'completed-task' + taskId: completedTask.taskId } }); // Wait a bit for the async handler to complete await new Promise(resolve => setTimeout(resolve, 10)); - expect(mockTaskStore.getTask).toHaveBeenCalledWith('completed-task', undefined); + expect(mockTaskStore.getTask).toHaveBeenCalledWith(completedTask.taskId, undefined); expect(mockTaskStore.updateTaskStatus).not.toHaveBeenCalled(); // eslint-disable-next-line @typescript-eslint/no-explicit-any const sentMessage = sendSpy.mock.calls[0][0] as any; diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 4d6fb0215..fde37d1ef 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -36,7 +36,6 @@ import { RequestInfo, GetTaskResult, TaskCreationParams, - TaskMetadata, RelatedTaskMetadata, CancelledNotification, Task @@ -159,14 +158,15 @@ export type TaskRequestOptions = Omit; */ export interface RequestTaskStore { /** - * Creates a new task with the given metadata and original request. + * Creates a new task with the given creation parameters. + * The implementation generates a unique taskId and createdAt timestamp. * - * @param task - The task creation metadata from the request + * @param taskParams - The task creation parameters from the request (ttl, pollInterval) * @param requestId - The JSON-RPC request ID * @param request - The original request that triggered task creation - * @returns The task state including status, ttl, pollInterval, and optional statusMessage + * @returns The task state including generated taskId, createdAt timestamp, status, ttl, pollInterval, and optional statusMessage */ - createTask(task: TaskMetadata, requestId: RequestId, request: Request): Promise; + createTask(taskParams: TaskCreationParams, requestId: RequestId, request: Request): Promise; /** * Gets the current status of a task. @@ -1081,13 +1081,13 @@ export abstract class Protocol { + createTask: async taskParams => { if (!request) { throw new Error('No request provided'); } - const result = await taskStore.createTask( - task, + const createdTask = await taskStore.createTask( + taskParams, request.id, { method: request.method, @@ -1096,14 +1096,14 @@ export abstract class Protocol { const task = await taskStore.getTask(taskId, sessionId); diff --git a/src/shared/task.ts b/src/shared/task.ts index e3d30e025..fd32aa979 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -1,4 +1,4 @@ -import { Task, TaskMetadata, Request, RequestId, Result } from '../types.js'; +import { Task, TaskCreationParams, Request, RequestId, Result } from '../types.js'; /** * Interface for storing and retrieving task state and results. @@ -8,15 +8,16 @@ import { Task, TaskMetadata, Request, RequestId, Result } from '../types.js'; */ export interface TaskStore { /** - * Creates a new task with the given metadata and original request. + * Creates a new task with the given creation parameters and original request. + * The implementation must generate a unique taskId and createdAt timestamp. * - * @param task - The task creation metadata from the request (includes taskId generated by receiver) + * @param taskParams - The task creation parameters from the request (ttl, pollInterval) * @param requestId - The JSON-RPC request ID * @param request - The original request that triggered task creation * @param sessionId - Optional session ID for binding the task to a specific session - * @returns The task state including status, ttl, pollInterval, and optional statusMessage + * @returns The task state including generated taskId, createdAt timestamp, status, ttl, pollInterval, and optional statusMessage */ - createTask(task: TaskMetadata, requestId: RequestId, request: Request, sessionId?: string): Promise; + createTask(taskParams: TaskCreationParams, requestId: RequestId, request: Request, sessionId?: string): Promise; /** * Gets the current status of a task. diff --git a/src/types.ts b/src/types.ts index 87ec79b13..f669294ee 100644 --- a/src/types.ts +++ b/src/types.ts @@ -52,13 +52,6 @@ export const TaskCreationParamsSchema = z */ .passthrough(); -/** - * @deprecated Use TaskCreationParamsSchema instead. This is kept for backward compatibility. - */ -export const TaskMetadataSchema = TaskCreationParamsSchema.extend({ - taskId: z.string() -}); - /** * Task association metadata, used to signal which task a message originated from. */ @@ -1968,7 +1961,6 @@ export type ProgressNotification = Infer; /* Tasks */ export type Task = Infer; export type TaskCreationParams = Infer; -export type TaskMetadata = Infer; export type RelatedTaskMetadata = Infer; export type CreateTaskResult = Infer; export type TaskCreatedNotification = Infer; From 064568b0b995ec05902f29d8346723689cb3e0c9 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 18 Nov 2025 16:52:39 -0800 Subject: [PATCH 35/84] Preliminary tasks/result implementation updates Full implementation will involve side-channeling SSE which will come later. --- src/client/index.test.ts | 14 ++++---- src/server/index.test.ts | 70 ++++++++++++++++++++-------------------- src/shared/protocol.ts | 43 ++++++++++++++++++++++-- 3 files changed, 82 insertions(+), 45 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 071ff7274..37539d807 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -2154,13 +2154,13 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } if (request.params.name === 'test-tool') { const result = { diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 816230530..d712797ff 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -983,13 +983,13 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } if (request.params.name === 'test-tool') { @@ -1184,13 +1184,13 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } // Capture the request to verify metadata later @@ -1211,13 +1211,13 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } if (request.params.name === 'collect-info') { @@ -1729,13 +1729,13 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } if (request.params.name === 'async-tool') { const delay = (request.params.arguments?.delay as number) || 10; @@ -1969,13 +1969,13 @@ test('should respect client task capabilities', async () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + request + ); + taskId = createdTask.taskId; } const result = { action: 'accept', diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index fde37d1ef..b965e89ce 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -369,15 +369,52 @@ export abstract class Protocol { + // Helper function to wait with abort signal support + const waitWithAbort = (ms: number, signal: AbortSignal): Promise => { + return new Promise((resolve, reject) => { + if (signal.aborted) { + reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled while waiting for task completion')); + return; + } + + const timeoutId = setTimeout(() => { + signal.removeEventListener('abort', abortHandler); + resolve(); + }, ms); + + const abortHandler = () => { + clearTimeout(timeoutId); + reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled while waiting for task completion')); + }; + + signal.addEventListener('abort', abortHandler, { once: true }); + }); + }; + const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); if (!task) { - throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); } - if (task.status !== 'completed') { - throw new McpError(ErrorCode.InvalidParams, `Cannot retrieve result: Task status is '${task.status}', not 'completed'`); + // If task is not in a terminal state, block until it reaches one + if (!isTerminal(task.status)) { + // Poll for task completion + let currentTask = task; + while (!isTerminal(currentTask.status)) { + // Wait for the poll interval before checking again + await waitWithAbort(currentTask.pollInterval ?? 5000, extra.signal); + + // Get updated task status + const updatedTask = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); + if (!updatedTask) { + throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); + } + currentTask = updatedTask; + } } + // Task is now in a terminal state (completed, failed, or cancelled) + // Retrieve and return the result const result = await this._taskStore!.getTaskResult(request.params.taskId, extra.sessionId); return { ...result, From 9296ea8267981117229ef1867d215f719aa8ab15 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 18 Nov 2025 17:11:13 -0800 Subject: [PATCH 36/84] Handle input_required state in PendingRequest We'll be rewriting this to use an async generator soon, but this implementation works for now. --- src/shared/request.test.ts | 253 +++++++++++++++++++++++++++++++++++++ src/shared/request.ts | 15 ++- 2 files changed, 265 insertions(+), 3 deletions(-) create mode 100644 src/shared/request.test.ts diff --git a/src/shared/request.test.ts b/src/shared/request.test.ts new file mode 100644 index 000000000..309652947 --- /dev/null +++ b/src/shared/request.test.ts @@ -0,0 +1,253 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { PendingRequest } from './request.js'; +import { Protocol } from './protocol.js'; +import { Request, Notification, Result, GetTaskResult } from '../types.js'; +import { z, ZodType } from 'zod'; + +// Mock Protocol class +class MockProtocol extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + + // Expose methods for testing + public mockGetTask = vi.fn(); + public mockGetTaskResult = vi.fn(); + + async getTask(params: { taskId: string }): Promise { + return this.mockGetTask(params); + } + + async getTaskResult(params: { taskId: string }, _resultSchema: ZodType): Promise { + return this.mockGetTaskResult(params, _resultSchema) as Promise; + } +} + +describe('PendingRequest', () => { + let protocol: MockProtocol; + const mockResultSchema = z.object({ result: z.string() }); + + beforeEach(() => { + protocol = new MockProtocol(); + }); + + describe('input_required status handling', () => { + it('should preemptively call tasks/result when input_required status is encountered', async () => { + // Setup: Create a task that transitions to input_required + const taskId = 'test-task-123'; + const expectedResult = { result: 'completed after input' }; + + // Mock getTask to return input_required status + protocol.mockGetTask.mockResolvedValueOnce({ + taskId, + status: 'input_required', + ttl: null, + createdAt: new Date().toISOString(), + pollInterval: 1000 + }); + + // Mock getTaskResult to return the final result + protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); + + // Create a PendingRequest with a task ID + const resultHandle = Promise.resolve(expectedResult); + const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); + + // Execute: Call result() which should trigger taskHandler + const result = await pendingRequest.result({ + onTaskCreated: vi.fn(), + onTaskStatus: vi.fn() + }); + + // Verify: getTask was called once + expect(protocol.mockGetTask).toHaveBeenCalledTimes(1); + expect(protocol.mockGetTask).toHaveBeenCalledWith({ taskId }); + + // Verify: getTaskResult was called immediately after detecting input_required + expect(protocol.mockGetTaskResult).toHaveBeenCalledTimes(1); + expect(protocol.mockGetTaskResult).toHaveBeenCalledWith({ taskId }, mockResultSchema); + + // Verify: Result is correct + expect(result).toEqual(expectedResult); + }); + + it('should call onTaskStatus before calling tasks/result for input_required', async () => { + const taskId = 'test-task-456'; + const expectedResult = { result: 'completed' }; + const onTaskStatus = vi.fn(); + + const inputRequiredTask: GetTaskResult = { + taskId, + status: 'input_required', + ttl: null, + createdAt: new Date().toISOString(), + pollInterval: 1000 + }; + + protocol.mockGetTask.mockResolvedValueOnce(inputRequiredTask); + protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); + + const resultHandle = Promise.resolve(expectedResult); + const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); + + await pendingRequest.result({ + onTaskCreated: vi.fn(), + onTaskStatus + }); + + // Verify: onTaskStatus was called with the input_required task + expect(onTaskStatus).toHaveBeenCalledWith(inputRequiredTask); + expect(onTaskStatus).toHaveBeenCalledBefore(protocol.mockGetTaskResult); + }); + + it('should not poll again after encountering input_required status', async () => { + const taskId = 'test-task-789'; + const expectedResult = { result: 'completed' }; + + protocol.mockGetTask.mockResolvedValueOnce({ + taskId, + status: 'input_required', + ttl: null, + createdAt: new Date().toISOString(), + pollInterval: 100 // Short interval to test that we don't wait + }); + + protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); + + const resultHandle = Promise.resolve(expectedResult); + const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); + + const startTime = Date.now(); + await pendingRequest.result({ + onTaskCreated: vi.fn(), + onTaskStatus: vi.fn() + }); + const endTime = Date.now(); + + // Verify: getTask was only called once (no polling) + expect(protocol.mockGetTask).toHaveBeenCalledTimes(1); + + // Verify: The operation completed quickly without waiting for pollInterval + expect(endTime - startTime).toBeLessThan(100); + }); + + it('should continue normal polling for working status before input_required', async () => { + const taskId = 'test-task-abc'; + const expectedResult = { result: 'completed' }; + + // First poll: working status + protocol.mockGetTask.mockResolvedValueOnce({ + taskId, + status: 'working', + ttl: null, + createdAt: new Date().toISOString(), + pollInterval: 10 + }); + + // Second poll: input_required status + protocol.mockGetTask.mockResolvedValueOnce({ + taskId, + status: 'input_required', + ttl: null, + createdAt: new Date().toISOString(), + pollInterval: 10 + }); + + protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); + + const resultHandle = Promise.resolve(expectedResult); + const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); + + await pendingRequest.result({ + onTaskCreated: vi.fn(), + onTaskStatus: vi.fn() + }); + + // Verify: getTask was called twice (once for working, once for input_required) + expect(protocol.mockGetTask).toHaveBeenCalledTimes(2); + + // Verify: getTaskResult was called after input_required was detected + expect(protocol.mockGetTaskResult).toHaveBeenCalledTimes(1); + }); + + it('should handle terminal status normally without input_required', async () => { + const taskId = 'test-task-def'; + const expectedResult = { result: 'completed' }; + + // Task is already completed + protocol.mockGetTask.mockResolvedValueOnce({ + taskId, + status: 'completed', + ttl: null, + createdAt: new Date().toISOString(), + pollInterval: 1000 + }); + + protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); + + const resultHandle = Promise.resolve(expectedResult); + const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); + + await pendingRequest.result({ + onTaskCreated: vi.fn(), + onTaskStatus: vi.fn() + }); + + // Verify: Normal flow - getTask once, then getTaskResult + expect(protocol.mockGetTask).toHaveBeenCalledTimes(1); + expect(protocol.mockGetTaskResult).toHaveBeenCalledTimes(1); + }); + }); + + describe('normal task polling', () => { + it('should poll until terminal status is reached', async () => { + const taskId = 'test-task-polling'; + const expectedResult = { result: 'completed' }; + + // First poll: working + protocol.mockGetTask.mockResolvedValueOnce({ + taskId, + status: 'working', + ttl: null, + createdAt: new Date().toISOString(), + pollInterval: 10 + }); + + // Second poll: still working + protocol.mockGetTask.mockResolvedValueOnce({ + taskId, + status: 'working', + ttl: null, + createdAt: new Date().toISOString(), + pollInterval: 10 + }); + + // Third poll: completed + protocol.mockGetTask.mockResolvedValueOnce({ + taskId, + status: 'completed', + ttl: null, + createdAt: new Date().toISOString(), + pollInterval: 10 + }); + + protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); + + const resultHandle = Promise.resolve(expectedResult); + const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); + + await pendingRequest.result({ + onTaskCreated: vi.fn(), + onTaskStatus: vi.fn() + }); + + // Verify: getTask was called three times + expect(protocol.mockGetTask).toHaveBeenCalledTimes(3); + + // Verify: getTaskResult was called once after terminal status + expect(protocol.mockGetTaskResult).toHaveBeenCalledTimes(1); + }); + }); +}); diff --git a/src/shared/request.ts b/src/shared/request.ts index 87c80b102..7333b57d6 100644 --- a/src/shared/request.ts +++ b/src/shared/request.ts @@ -67,9 +67,18 @@ export class PendingRequest - setTimeout(resolve, task.pollInterval ?? this.defaultTaskPollInterval ?? DEFAULT_TASK_POLLING_INTERVAL) - ); + + // Handle input_required status: preemptively call tasks/result instead of continuing to poll + // This allows the receiver to block and wait for user input before returning the result + if (task.status === 'input_required') { + return await this.protocol.getTaskResult({ taskId: taskId }, this.resultSchema); + } + + if (!isTerminal(task.status)) { + await new Promise(resolve => + setTimeout(resolve, task.pollInterval ?? this.defaultTaskPollInterval ?? DEFAULT_TASK_POLLING_INTERVAL) + ); + } } while (!isTerminal(task.status)); // Process result From ffc12829d3190be35fc76ee9701768e64306867b Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 18 Nov 2025 17:52:22 -0800 Subject: [PATCH 37/84] Implement task status notifications --- src/shared/protocol.test.ts | 50 +++++++++++++++++++++++++++++++++++++ src/shared/protocol.ts | 26 +++++++++++++++++-- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 37fe5fb29..65c6c101e 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1547,4 +1547,54 @@ describe('Task-based execution', () => { expect(result._meta).toBeDefined(); }); }); + + describe('task status notifications', () => { + it('should call getTask after updateTaskStatus to enable notification sending', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task first + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + + await serverProtocol.connect(serverTransport); + + // Simulate cancelling the task + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify that updateTaskStatus was called + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith( + task.taskId, + 'cancelled', + 'Client cancelled task execution.', + undefined + ); + + // Verify that getTask was called after updateTaskStatus + // This is done by the RequestTaskStore wrapper to get the updated task for the notification + const getTaskCalls = mockTaskStore.getTask.mock.calls; + const lastGetTaskCall = getTaskCalls[getTaskCalls.length - 1]; + expect(lastGetTaskCall[0]).toBe(task.taskId); + }); + }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index b965e89ce..fe299235d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1158,8 +1158,19 @@ export abstract class Protocol { - return taskStore.storeTaskResult(taskId, result, sessionId); + storeTaskResult: async (taskId, result) => { + await taskStore.storeTaskResult(taskId, result, sessionId); + + // Get updated task state and send notification + const task = await taskStore.getTask(taskId, sessionId); + if (task) { + await this.notification({ + method: 'notifications/tasks/status', + params: { + task + } + } as unknown as SendNotificationT); + } }, getTaskResult: taskId => { return taskStore.getTaskResult(taskId, sessionId); @@ -1182,6 +1193,17 @@ export abstract class Protocol Date: Tue, 18 Nov 2025 17:53:16 -0800 Subject: [PATCH 38/84] Remove task creation notification --- src/shared/protocol.ts | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index fe299235d..3689d7475 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1123,7 +1123,7 @@ export abstract class Protocol { const task = await taskStore.getTask(taskId, sessionId); From b945454ff527d4dcf553247ac83e29cb49c53519 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 18 Nov 2025 17:54:45 -0800 Subject: [PATCH 39/84] Remove types for task creation notification --- src/types.ts | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/types.ts b/src/types.ts index f669294ee..e900697f0 100644 --- a/src/types.ts +++ b/src/types.ts @@ -651,13 +651,6 @@ export const CreateTaskResultSchema = ResultSchema.extend({ task: TaskSchema }); -/** - * An out-of-band notification used to inform the receiver of a task being created. - */ -export const TaskCreatedNotificationSchema = NotificationSchema.extend({ - method: z.literal('notifications/tasks/created') -}); - /** * Parameters for task status notification. */ @@ -1805,8 +1798,7 @@ export const ClientNotificationSchema = z.union([ CancelledNotificationSchema, ProgressNotificationSchema, InitializedNotificationSchema, - RootsListChangedNotificationSchema, - TaskCreatedNotificationSchema + RootsListChangedNotificationSchema ]); export const ClientResultSchema = z.union([ @@ -1837,8 +1829,7 @@ export const ServerNotificationSchema = z.union([ ResourceUpdatedNotificationSchema, ResourceListChangedNotificationSchema, ToolListChangedNotificationSchema, - PromptListChangedNotificationSchema, - TaskCreatedNotificationSchema + PromptListChangedNotificationSchema ]); export const ServerResultSchema = z.union([ @@ -1963,7 +1954,6 @@ export type Task = Infer; export type TaskCreationParams = Infer; export type RelatedTaskMetadata = Infer; export type CreateTaskResult = Infer; -export type TaskCreatedNotification = Infer; export type TaskStatusNotificationParams = Infer; export type TaskStatusNotification = Infer; export type GetTaskRequest = Infer; From 1c7b332c90936ecdaf9eca099a00df58a21a65b9 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 18 Nov 2025 18:00:27 -0800 Subject: [PATCH 40/84] Use actual notification schema --- src/shared/protocol.ts | 14 +++++++++----- src/types.ts | 6 ++++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 3689d7475..5b9fe16b5 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -38,7 +38,9 @@ import { TaskCreationParams, RelatedTaskMetadata, CancelledNotification, - Task + Task, + TaskStatusNotification, + TaskStatusNotificationSchema } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; @@ -1147,12 +1149,13 @@ export abstract class Protocol { @@ -1180,12 +1183,13 @@ export abstract class Protocol Date: Wed, 19 Nov 2025 10:05:43 -0800 Subject: [PATCH 41/84] Remove related-task metadata from messages that already include it as a parameter --- src/shared/protocol.test.ts | 245 ++++++++++++++++++++++++++++++++++++ src/shared/protocol.ts | 28 +++-- 2 files changed, 264 insertions(+), 9 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 65c6c101e..d2c93937a 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1597,4 +1597,249 @@ describe('Task-based execution', () => { expect(lastGetTaskCall[0]).toBe(task.taskId); }); }); + + describe('task metadata handling', () => { + it('should NOT include related-task metadata in tasks/get response', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task first + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Request task status + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/get', + params: { + taskId: task.taskId + } + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify response does NOT include related-task metadata + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + result: expect.objectContaining({ + taskId: task.taskId, + status: 'working' + }) + }) + ); + + // Verify _meta is not present or doesn't contain RELATED_TASK_META_KEY + const response = sendSpy.mock.calls[0][0] as { result?: { _meta?: Record } }; + expect(response.result?._meta?.[RELATED_TASK_META_KEY]).toBeUndefined(); + }); + + it('should NOT include related-task metadata in tasks/list response', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task first + await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Request task list + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/list', + params: {} + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify response does NOT include related-task metadata + const response = sendSpy.mock.calls[0][0] as { result?: { _meta?: Record } }; + expect(response.result?._meta).toEqual({}); + }); + + it('should NOT include related-task metadata in tasks/cancel response', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task first + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Cancel the task + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify response does NOT include related-task metadata + const response = sendSpy.mock.calls[0][0] as { result?: { _meta?: Record } }; + expect(response.result?._meta).toEqual({}); + }); + + it('should include related-task metadata in tasks/result response', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task and complete it + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const testResult = { + content: [{ type: 'text', text: 'test result' }] + }; + + await mockTaskStore.storeTaskResult(task.taskId, testResult); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Request task result + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/result', + params: { + taskId: task.taskId + } + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify response DOES include related-task metadata + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + result: expect.objectContaining({ + content: testResult.content, + _meta: expect.objectContaining({ + [RELATED_TASK_META_KEY]: { + taskId: task.taskId + } + }) + }) + }) + ); + }); + + it('should propagate related-task metadata to handler sendRequest and sendNotification', async () => { + const mockTaskStore = createMockTaskStore(); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Set up a handler that uses sendRequest and sendNotification + serverProtocol.setRequestHandler(CallToolRequestSchema, async (_request, extra) => { + // Send a notification using the extra.sendNotification + await extra.sendNotification({ + method: 'notifications/message', + params: { level: 'info', data: 'test' } + }); + + return { + content: [{ type: 'text', text: 'done' }] + }; + }); + + // Send a request with related-task metadata + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { + name: 'test-tool', + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: 'parent-task-123' + } + } + } + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the notification includes related-task metadata + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'notifications/message', + params: expect.objectContaining({ + _meta: expect.objectContaining({ + [RELATED_TASK_META_KEY]: { + taskId: 'parent-task-123' + } + }) + }) + }), + expect.any(Object) + ); + }); + }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 5b9fe16b5..ce457dbc4 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -359,14 +359,11 @@ export abstract class Protocol = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, _meta: request.params?._meta, sendNotification: async notification => { - await this.notification(notification, { relatedRequestId: request.id }); + // Include related-task metadata if this request is part of a task (Requirement 6.1) + const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; + if (relatedTaskId) { + notificationOptions.relatedTask = { taskId: relatedTaskId }; + } + await this.notification(notification, notificationOptions); }, sendRequest: async (r, resultSchema, options?) => { - return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id }); + // Include related-task metadata if this request is part of a task (Requirement 6.1) + const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; + if (relatedTaskId && !requestOptions.relatedTask) { + requestOptions.relatedTask = { taskId: relatedTaskId }; + } + return await this.request(r, resultSchema, requestOptions); }, authInfo: extra?.authInfo, requestId: request.id, requestInfo: extra?.requestInfo, - taskId: undefined, + taskId: relatedTaskId, taskStore: taskStore, taskRequestedTtl: taskCreationParams?.ttl }; From ad220e0027bb628bae8233fe00a8cca80bea2366 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 12:18:58 -0800 Subject: [PATCH 42/84] Update taskHint implementation on server --- src/server/mcp.test.ts | 776 ++++++++++++++++++++++++++++++++++++++++- src/server/mcp.ts | 236 +++++++++---- 2 files changed, 934 insertions(+), 78 deletions(-) diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 428c25ae8..629874e29 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -5,6 +5,7 @@ import { getDisplayName } from '../shared/metadataUtils.js'; import { UriTemplate } from '../shared/uriTemplate.js'; import { CallToolResultSchema, + type CallToolResult, CompleteResultSchema, ElicitRequestSchema, GetPromptResultSchema, @@ -19,6 +20,23 @@ import { } from '../types.js'; import { completable } from './completable.js'; import { McpServer, ResourceTemplate } from './mcp.js'; +import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; + +function createLatch() { + let latch = false; + const waitForLatch = async () => { + while (!latch) { + await new Promise(resolve => setTimeout(resolve, 0)); + } + }; + + return { + releaseLatch: () => { + latch = true; + }, + waitForLatch + }; +} describe('McpServer', () => { /*** @@ -556,7 +574,7 @@ describe('tool()', () => { inputSchema: { name: z.string(), value: z.number() } }, async ({ name, value }) => ({ - content: [{ type: 'text', text: `${name}: ${value}` }] + content: [{ type: 'text' as const, text: `${name}: ${value}` }] }) ); @@ -716,7 +734,7 @@ describe('tool()', () => { }); mcpServer.tool('test', { name: z.string() }, { title: 'Test Tool', readOnlyHint: true }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [{ type: 'text' as const, text: `Hello, ${name}!` }] })); mcpServer.registerTool( @@ -726,7 +744,7 @@ describe('tool()', () => { annotations: { title: 'Test Tool', readOnlyHint: true } }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [{ type: 'text' as const, text: `Hello, ${name}!` }] }) ); @@ -770,7 +788,7 @@ describe('tool()', () => { { name: z.string() }, { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false, taskHint: 'never' }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [{ type: 'text' as const, text: `Hello, ${name}!` }] }) ); @@ -787,7 +805,7 @@ describe('tool()', () => { } }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [{ type: 'text' as const, text: `Hello, ${name}!` }] }) ); @@ -840,7 +858,7 @@ describe('tool()', () => { taskHint: 'never' }, async () => ({ - content: [{ type: 'text', text: 'Test response' }] + content: [{ type: 'text' as const, text: 'Test response' }] }) ); @@ -1651,7 +1669,7 @@ describe('tool()', () => { _meta: metaData }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [{ type: 'text' as const, text: `Hello, ${name}!` }] }) ); @@ -1687,7 +1705,7 @@ describe('tool()', () => { inputSchema: { name: z.string() } }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [{ type: 'text' as const, text: `Hello, ${name}!` }] }) ); @@ -3684,7 +3702,7 @@ describe('Tool title precedence', () => { // Tool 1: Only name mcpServer.tool('tool_name_only', async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [{ type: 'text' as const, text: 'Response' }] })); // Tool 2: Name and annotations.title @@ -3695,7 +3713,7 @@ describe('Tool title precedence', () => { title: 'Annotations Title' }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [{ type: 'text' as const, text: 'Response' }] }) ); @@ -4307,11 +4325,11 @@ describe('Tools with union and intersection schemas', () => { server.registerTool('contact', { inputSchema: unionSchema }, async args => { if (args.type === 'email') { return { - content: [{ type: 'text', text: `Email contact: ${args.email}` }] + content: [{ type: 'text' as const, text: `Email contact: ${args.email}` }] }; } else { return { - content: [{ type: 'text', text: `Phone contact: ${args.phone}` }] + content: [{ type: 'text' as const, text: `Phone contact: ${args.phone}` }] }; } }); @@ -4477,7 +4495,7 @@ describe('Tools with union and intersection schemas', () => { server.registerTool('union-test', { inputSchema: unionSchema }, async () => { return { - content: [{ type: 'text', text: 'Success' }] + content: [{ type: 'text' as const, text: 'Success' }] }; }); @@ -4522,3 +4540,735 @@ describe('Tools with union and intersection schemas', () => { ); }); }); + +describe('Tool-level task hints with automatic polling wrapper', () => { + test('should return error for tool with taskHint "always" called without task augmentation', async () => { + const taskStore = new InMemoryTaskStore(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool with taskHint "always" BEFORE connecting + mcpServer.registerToolTask( + 'long-running-task', + { + description: 'A long running task', + inputSchema: { + input: z.string() + }, + annotations: { + taskHint: 'always' as unknown as 'never' // override to allow violating build-time constraints + } + }, + { + createTask: async ({ input }, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { + method: 'tools/call', + params: { name: 'long-running-task', arguments: { input } } + }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async work + setTimeout(async () => { + await store.storeTaskResult(task.taskId, { + content: [{ type: 'text' as const, text: `Processed: ${input}` }] + }); + }, 200); + + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async (_input, extra) => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITHOUT task augmentation - should return error + const result = await client.callTool( + { + name: 'long-running-task', + arguments: { input: 'test data' } + }, + CallToolResultSchema + ); + + // Should receive error result + expect(result.isError).toBe(true); + const content = result.content as TextContent[]; + expect(content[0].text).toContain('requires task augmentation'); + + taskStore.cleanup(); + }); + + test('should automatically poll and return CallToolResult for tool with taskHint "optional" called without task augmentation', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool with taskHint "optional" BEFORE connecting + mcpServer.registerToolTask( + 'optional-task', + { + description: 'An optional task', + inputSchema: { + value: z.number() + }, + annotations: { + taskHint: 'optional' as unknown as 'never' // override to allow violating build-time constraints + } + }, + { + createTask: async ({ value }, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { + method: 'tools/call', + params: { name: 'optional-task', arguments: { value } } + }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async work + setTimeout(async () => { + await store.storeTaskResult(task.taskId, { + content: [{ type: 'text' as const, text: `Result: ${value * 2}` }] + }); + releaseLatch(); + }, 150); + + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async (_value, extra) => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITHOUT task augmentation + const result = await client.callTool( + { + name: 'optional-task', + arguments: { value: 21 } + }, + CallToolResultSchema + ); + + // Should receive CallToolResult directly, not CreateTaskResult + expect(result).toHaveProperty('content'); + expect(result.content).toEqual([{ type: 'text' as const, text: 'Result: 42' }]); + expect(result).not.toHaveProperty('task'); + + // Wait for async operations to complete + await waitForLatch(); + taskStore.cleanup(); + }); + + test('should return CreateTaskResult when tool with taskHint "always" is called WITH task augmentation', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool with taskHint "always" BEFORE connecting + mcpServer.registerToolTask( + 'task-tool', + { + description: 'A task tool', + inputSchema: { + data: z.string() + }, + annotations: { + taskHint: 'always' as unknown as 'never' // override to allow violating build-time constraints + } + }, + { + createTask: async ({ data }, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { + method: 'tools/call', + params: { name: 'task-tool', arguments: { data } } + }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async work + setTimeout(async () => { + await store.storeTaskResult(task.taskId, { + content: [{ type: 'text' as const, text: `Completed: ${data}` }] + }); + releaseLatch(); + }, 200); + + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async (_data, extra) => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITH task augmentation + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'task-tool', + arguments: { data: 'test' }, + task: { ttl: 60000 } + } + }, + z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.union([z.number(), z.null()]), + createdAt: z.string(), + pollInterval: z.number().optional() + }) + }) + ); + + // Should receive CreateTaskResult with task field + expect(result).toHaveProperty('task'); + expect(result.task).toHaveProperty('taskId'); + expect(result.task.status).toBe('working'); + + // Wait for async operations to complete + await waitForLatch(); + taskStore.cleanup(); + }); + + test('should throw error if tool with taskHint "always" is not registered with registerToolTask', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a regular tool with taskHint "always" (incorrect usage) BEFORE connecting + mcpServer.registerTool( + 'bad-tool', + { + description: 'A tool with incorrect taskHint', + annotations: { + taskHint: 'always' + } + }, + async () => ({ + content: [{ type: 'text' as const, text: 'Should not work' }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool - should return error result + const result = await client.callTool( + { + name: 'bad-tool', + arguments: {} + }, + CallToolResultSchema + ); + + expect(result.isError).toBe(true); + const content = result.content as TextContent[]; + expect(content[0].text).toContain("has taskHint 'always' but was not registered with registerToolTask"); + }); + + test('should throw error if tool with taskHint "optional" is not registered with registerToolTask', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a regular tool with taskHint "optional" (incorrect usage) BEFORE connecting + mcpServer.registerTool( + 'bad-optional-tool', + { + description: 'A tool with incorrect taskHint', + annotations: { + taskHint: 'optional' + } + }, + async () => ({ + content: [{ type: 'text' as const, text: 'Should not work' }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool - should return error result + const result = await client.callTool( + { + name: 'bad-optional-tool', + arguments: {} + }, + CallToolResultSchema + ); + + expect(result.isError).toBe(true); + const content = result.content as TextContent[]; + expect(content[0].text).toContain("has taskHint 'optional' but was not registered with registerToolTask"); + }); + + test('should work normally for tool with taskHint "never"', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a regular tool with taskHint "never" BEFORE connecting + mcpServer.registerTool( + 'normal-tool', + { + description: 'A normal tool', + inputSchema: { + message: z.string() + }, + annotations: { + taskHint: 'never' + } + }, + async ({ message }) => ({ + content: [{ type: 'text' as const, text: `Echo: ${message}` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool - should work normally + const result = await client.callTool( + { + name: 'normal-tool', + arguments: { message: 'hello' } + }, + CallToolResultSchema + ); + + expect(result.content).toEqual([{ type: 'text' as const, text: 'Echo: hello' }]); + }); + + test('should work normally for tool without taskHint', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a regular tool without taskHint BEFORE connecting + mcpServer.registerTool( + 'simple-tool', + { + description: 'A simple tool', + inputSchema: { + value: z.number() + } + }, + async ({ value }) => ({ + content: [{ type: 'text' as const, text: `Value: ${value}` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool - should work normally + const result = await client.callTool( + { + name: 'simple-tool', + arguments: { value: 42 } + }, + CallToolResultSchema + ); + + expect(result.content).toEqual([{ type: 'text' as const, text: 'Value: 42' }]); + }); + + test('should handle task failures during automatic polling', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool that fails BEFORE connecting + mcpServer.registerToolTask( + 'failing-task', + { + description: 'A failing task', + annotations: { + taskHint: 'optional' as unknown as 'never' // override to allow violating build-time constraints + } + }, + { + createTask: async extra => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { + method: 'tools/call', + params: { name: 'failing-task', arguments: {} } + }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async failure + setTimeout(async () => { + await store.updateTaskStatus(task.taskId, 'failed', 'Task failed'); + await store.storeTaskResult(task.taskId, { + content: [{ type: 'text' as const, text: 'Error occurred' }], + isError: true + }); + releaseLatch(); + }, 150); + + return { task }; + }, + getTask: async extra => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async extra => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITHOUT task augmentation + const result = await client.callTool( + { + name: 'failing-task', + arguments: {} + }, + CallToolResultSchema + ); + + // Should receive the error result + expect(result).toHaveProperty('content'); + expect(result.content).toEqual([{ type: 'text' as const, text: 'Error occurred' }]); + expect(result.isError).toBe(true); + + // Wait for async operations to complete + await waitForLatch(); + taskStore.cleanup(); + }); + + test('should handle task cancellation during automatic polling', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool that gets cancelled BEFORE connecting + mcpServer.registerToolTask( + 'cancelled-task', + { + description: 'A task that gets cancelled', + annotations: { + taskHint: 'optional' as unknown as 'never' // override to allow violating build-time constraints + } + }, + { + createTask: async extra => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { + method: 'tools/call', + params: { name: 'cancelled-task', arguments: {} } + }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async cancellation + setTimeout(async () => { + await store.updateTaskStatus(task.taskId, 'cancelled', 'Task was cancelled'); + await store.storeTaskResult(task.taskId, { + content: [{ type: 'text' as const, text: 'Task cancelled' }] + }); + releaseLatch(); + }, 150); + + return { task }; + }, + getTask: async extra => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async extra => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITHOUT task augmentation + const result = await client.callTool( + { + name: 'cancelled-task', + arguments: {} + }, + CallToolResultSchema + ); + + // Should receive the cancellation result + expect(result).toHaveProperty('content'); + expect(result.content).toEqual([{ type: 'text' as const, text: 'Task cancelled' }]); + + // Wait for async operations to complete + await waitForLatch(); + taskStore.cleanup(); + }); +}); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index a1bcd3d9f..aaaa5595c 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -37,7 +37,8 @@ import { CompleteRequestPrompt, CompleteRequestResourceTemplate, assertCompleteRequestPrompt, - assertCompleteRequestResourceTemplate + assertCompleteRequestResourceTemplate, + CallToolRequest } from '../types.js'; import { Completable, CompletableDef } from './completable.js'; import { UriTemplate, Variables } from '../shared/uriTemplate.js'; @@ -133,89 +134,55 @@ export class McpServer { ); this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { - const tool = this._registeredTools[request.params.name]; - - let result: CallToolResult | CreateTaskResult; - try { + const tool = this._registeredTools[request.params.name]; if (!tool) { throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); } - if (!tool.enabled) { throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); } const isTaskRequest = !!request.params.task; - if (tool.inputSchema) { - const parseResult = await tool.inputSchema.safeParseAsync(request.params.arguments); - if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Input validation error: Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}` - ); - } - - const args = parseResult.data; + const taskHint = tool.annotations?.taskHint; + const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); - const handler = tool.handler as AnyToolHandler; - if ('createTask' in handler) { - const cb = handler.createTask; - if (!extra.taskStore) { - throw new Error('No task store provided.'); - } + // Validate task hint configuration + if ((taskHint === 'always' || taskHint === 'optional') && !isTaskHandler) { + throw new McpError( + ErrorCode.InternalError, + `Tool ${request.params.name} has taskHint '${taskHint}' but was not registered with registerToolTask` + ); + } - // Needed to show the compiler this field exists - const taskExtra = { ...extra, taskStore: extra.taskStore }; - result = await Promise.resolve(cb(args, taskExtra)); - } else { - const cb = handler; - result = await Promise.resolve(cb(args, extra)); - } - } else { - const handler = tool.handler as AnyToolHandler; - if ('createTask' in handler) { - const cb = handler.createTask; - if (!extra.taskStore) { - throw new Error('No task store provided.'); - } + // Handle taskHint 'always' without task augmentation + if (taskHint === 'always' && !isTaskRequest) { + throw new McpError( + ErrorCode.MethodNotFound, + `Tool ${request.params.name} requires task augmentation (taskHint: 'always')` + ); + } - // Needed to show the compiler this field exists - const taskExtra = { ...extra, taskStore: extra.taskStore }; - result = await Promise.resolve(cb(taskExtra)); - } else { - const cb = handler; - result = await Promise.resolve(cb(extra)); - } + // Handle taskHint 'optional' without task augmentation - automatic polling + if (taskHint === 'optional' && !isTaskRequest && isTaskHandler) { + return await this.handleAutomaticTaskPolling(tool, request, extra); } + // Normal execution path + const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); + const result = await this.executeToolHandler(tool, args, extra); + + // Return CreateTaskResult immediately for task requests if (isTaskRequest) { - // Return the CreateTaskResult immediately return result; } - if (tool.outputSchema && !result.isError) { - if (!result.structuredContent) { - throw new McpError( - ErrorCode.InvalidParams, - `Output validation error: Tool ${request.params.name} has an output schema but no structured content was provided` - ); - } - - // if the tool has an output schema, validate structured content - const parseResult = await tool.outputSchema.safeParseAsync(result.structuredContent); - if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Output validation error: Invalid structured content for tool ${request.params.name}: ${parseResult.error.message}` - ); - } - } + // Validate output schema for non-task requests + await this.validateToolOutput(tool, result, request.params.name); + return result; } catch (error) { return this.createToolError(error instanceof Error ? error.message : String(error)); } - - return result; }); this._toolHandlersInitialized = true; @@ -239,6 +206,141 @@ export class McpServer { }; } + /** + * Validates tool input arguments against the tool's input schema. + */ + private async validateToolInput< + Tool extends RegisteredTool, + Args extends Tool['inputSchema'] extends infer InputSchema + ? InputSchema extends ZodType + ? z.infer + : undefined + : undefined + >(tool: Tool, args: Args, toolName: string): Promise { + if (!tool.inputSchema) { + return undefined as Args; + } + + const parseResult = await tool.inputSchema.safeParseAsync(args); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Input validation error: Invalid arguments for tool ${toolName}: ${parseResult.error.message}` + ); + } + + return parseResult.data as unknown as Args; + } + + /** + * Validates tool output against the tool's output schema. + */ + private async validateToolOutput(tool: RegisteredTool, result: CallToolResult | CreateTaskResult, toolName: string): Promise { + if (!tool.outputSchema) { + return; + } + + // Only validate CallToolResult, not CreateTaskResult + if (!('content' in result)) { + return; + } + + if (result.isError) { + return; + } + + if (!result.structuredContent) { + throw new McpError( + ErrorCode.InvalidParams, + `Output validation error: Tool ${toolName} has an output schema but no structured content was provided` + ); + } + + const parseResult = await tool.outputSchema.safeParseAsync(result.structuredContent); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Output validation error: Invalid structured content for tool ${toolName}: ${parseResult.error.message}` + ); + } + } + + /** + * Executes a tool handler (either regular or task-based). + */ + private async executeToolHandler( + tool: RegisteredTool, + args: unknown, + extra: RequestHandlerExtra + ): Promise { + const handler = tool.handler as AnyToolHandler; + const isTaskHandler = 'createTask' in handler; + + if (isTaskHandler) { + if (!extra.taskStore) { + throw new Error('No task store provided.'); + } + const taskExtra = { ...extra, taskStore: extra.taskStore }; + + if (tool.inputSchema) { + const typedHandler = handler as ToolTaskHandler; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return await Promise.resolve(typedHandler.createTask(args as any, taskExtra)); + } else { + const typedHandler = handler as ToolTaskHandler; + return await Promise.resolve(typedHandler.createTask(taskExtra)); + } + } + + if (tool.inputSchema) { + const typedHandler = handler as ToolCallback; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return await Promise.resolve(typedHandler(args as any, extra)); + } else { + const typedHandler = handler as ToolCallback; + return await Promise.resolve(typedHandler(extra)); + } + } + + /** + * Handles automatic task polling for tools with taskHint 'optional'. + */ + private async handleAutomaticTaskPolling( + tool: RegisteredTool, + request: RequestT, + extra: RequestHandlerExtra + ): Promise { + if (!extra.taskStore) { + throw new Error('No task store provided for task-capable tool.'); + } + + // Validate input and create task + const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); + const handler = tool.handler as ToolTaskHandler; + const taskExtra = { ...extra, taskStore: extra.taskStore }; + + const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined + ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, taskExtra)) + : await Promise.resolve((handler as ToolTaskHandler).createTask(taskExtra)); + + // Poll until completion + const taskId = createTaskResult.task.taskId; + let task = createTaskResult.task; + const pollInterval = task.pollInterval ?? 5000; + + while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { + await new Promise(resolve => setTimeout(resolve, pollInterval)); + const updatedTask = await extra.taskStore.getTask(taskId); + if (!updatedTask) { + throw new McpError(ErrorCode.InternalError, `Task ${taskId} not found during polling`); + } + task = updatedTask; + } + + // Return the final result + return (await extra.taskStore.getTaskResult(taskId)) as CallToolResult; + } + private _completionHandlerInitialized = false; private setCompletionRequestHandler() { @@ -887,7 +989,7 @@ export class McpServer { description?: string; inputSchema?: InputArgs; outputSchema?: OutputArgs; - annotations?: ToolAnnotations; + annotations?: NoTaskToolAnnotations; _meta?: Record; }, handler: ToolTaskHandler @@ -898,7 +1000,7 @@ export class McpServer { config.description, config.inputSchema, config.outputSchema, - { ...config.annotations, taskHint: 'always' }, + { taskHint: 'always', ...config.annotations }, config._meta, handler ); @@ -1158,6 +1260,10 @@ export type AnyToolCallback = undefined> = ToolCallback | ToolTaskHandler; +export interface NoTaskToolAnnotations extends ToolAnnotations { + taskHint?: 'never'; +} + export type RegisteredTool = { title?: string; description?: string; From a3ebe62542730ca31456b2bccab95c074460811e Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 12:51:22 -0800 Subject: [PATCH 43/84] Add tests for task cancellation vs request cancellation --- src/shared/protocol.test.ts | 293 ++++++++++++++++++++++++++++++++++++ 1 file changed, 293 insertions(+) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index d2c93937a..870c63fe6 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1843,3 +1843,296 @@ describe('Task-based execution', () => { }); }); }); + +describe('Request Cancellation vs Task Cancellation', () => { + let protocol: Protocol; + let transport: MockTransport; + let taskStore: TaskStore; + + beforeEach(() => { + transport = new MockTransport(); + taskStore = createMockTaskStore(); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + }); + + describe('notifications/cancelled behavior', () => { + test('should abort request handler when notifications/cancelled is received', async () => { + await protocol.connect(transport); + + // Set up a request handler that checks if it was aborted + let wasAborted = false; + const TestRequestSchema = z.object({ + method: z.literal('test/longRunning'), + params: z.optional(z.record(z.unknown())) + }); + protocol.setRequestHandler(TestRequestSchema, async (_request, extra) => { + // Simulate a long-running operation + await new Promise(resolve => setTimeout(resolve, 100)); + wasAborted = extra.signal.aborted; + return { _meta: {} } as Result; + }); + + // Simulate an incoming request + const requestId = 123; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: requestId, + method: 'test/longRunning', + params: {} + }); + } + + // Wait a bit for the handler to start + await new Promise(resolve => setTimeout(resolve, 10)); + + // Send cancellation notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { + requestId: requestId, + reason: 'User cancelled' + } + }); + } + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 150)); + + // Verify the request was aborted + expect(wasAborted).toBe(true); + }); + + test('should NOT automatically cancel associated tasks when notifications/cancelled is received', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + + // Send cancellation notification for the request + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { + requestId: 'req-1', + reason: 'User cancelled' + } + }); + } + + // Wait a bit + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify the task status was NOT changed to cancelled + const updatedTask = await taskStore.getTask(task.taskId); + expect(updatedTask?.status).toBe('working'); + expect(taskStore.updateTaskStatus).not.toHaveBeenCalledWith(task.taskId, 'cancelled', expect.any(String)); + }); + }); + + describe('tasks/cancel behavior', () => { + test('should cancel task independently of request cancellation', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + + // Cancel the task using tasks/cancel + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 999, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + } + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify the task was cancelled + expect(taskStore.updateTaskStatus).toHaveBeenCalledWith( + task.taskId, + 'cancelled', + 'Client cancelled task execution.', + undefined + ); + }); + + test('should reject cancellation of terminal tasks', async () => { + await protocol.connect(transport); + const sendSpy = vi.spyOn(transport, 'send'); + + // Create a task and mark it as completed + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + await taskStore.updateTaskStatus(task.taskId, 'completed'); + + // Try to cancel the completed task + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 999, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + } + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify an error was sent + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: 999, + error: expect.objectContaining({ + code: ErrorCode.InvalidParams, + message: expect.stringContaining('Cannot cancel task in terminal status') + }) + }) + ); + }); + + test('should return error when task not found', async () => { + await protocol.connect(transport); + const sendSpy = vi.spyOn(transport, 'send'); + + // Try to cancel a non-existent task + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 999, + method: 'tasks/cancel', + params: { + taskId: 'non-existent-task' + } + }); + } + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify an error was sent + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: 999, + error: expect.objectContaining({ + code: ErrorCode.InvalidParams, + message: expect.stringContaining('Task not found') + }) + }) + ); + }); + }); + + describe('separation of concerns', () => { + test('should allow request cancellation without affecting task', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + + // Cancel the request (not the task) + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { + requestId: 'req-1', + reason: 'User cancelled request' + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify task is still working + const updatedTask = await taskStore.getTask(task.taskId); + expect(updatedTask?.status).toBe('working'); + }); + + test('should allow task cancellation without affecting request', async () => { + await protocol.connect(transport); + + // Set up a request handler + let requestCompleted = false; + const TestMethodSchema = z.object({ + method: z.literal('test/method'), + params: z.optional(z.record(z.unknown())) + }); + protocol.setRequestHandler(TestMethodSchema, async () => { + await new Promise(resolve => setTimeout(resolve, 50)); + requestCompleted = true; + return { _meta: {} } as Result; + }); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + + // Start a request + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 123, + method: 'test/method', + params: {} + }); + } + + // Cancel the task (not the request) + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 999, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + } + + // Wait for request to complete + await new Promise(resolve => setTimeout(resolve, 100)); + + // Verify request completed normally + expect(requestCompleted).toBe(true); + + // Verify task was cancelled + expect(taskStore.updateTaskStatus).toHaveBeenCalledWith( + task.taskId, + 'cancelled', + 'Client cancelled task execution.', + undefined + ); + }); + }); +}); From 901ec9c8d4b672139da826199736ed39099cf7f3 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 14:49:41 -0800 Subject: [PATCH 44/84] Validate task status transitions; allow failure result --- src/client/index.test.ts | 18 +- src/examples/server/simpleStreamableHttp.ts | 2 +- src/examples/shared/inMemoryTaskStore.test.ts | 158 +++++++++++++++++- src/examples/shared/inMemoryTaskStore.ts | 18 +- src/server/index.test.ts | 16 +- src/server/mcp.test.ts | 16 +- src/shared/protocol.test.ts | 8 +- src/shared/protocol.ts | 18 +- src/shared/task.test.ts | 26 +++ src/shared/task.ts | 5 +- 10 files changed, 237 insertions(+), 48 deletions(-) create mode 100644 src/shared/task.test.ts diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 37539d807..9b97b8e3b 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -1382,7 +1382,7 @@ describe('Task-based execution', () => { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; } @@ -1468,7 +1468,7 @@ describe('Task-based execution', () => { content: [{ type: 'text', text: 'Success!' }] }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; } @@ -1554,7 +1554,7 @@ describe('Task-based execution', () => { content: [{ type: 'text', text: 'Result data!' }] }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; } @@ -1637,7 +1637,7 @@ describe('Task-based execution', () => { content: [{ type: 'text', text: 'Success!' }] }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; } @@ -1748,7 +1748,7 @@ describe('Task-based execution', () => { content: { username: 'list-user' } }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; }); @@ -1851,7 +1851,7 @@ describe('Task-based execution', () => { content: { username: 'list-user' } }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; }); @@ -1952,7 +1952,7 @@ describe('Task-based execution', () => { content: { username: 'result-user' } }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; }); @@ -2052,7 +2052,7 @@ describe('Task-based execution', () => { content: { username: 'list-user' } }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; }); @@ -2167,7 +2167,7 @@ describe('Task-based execution', () => { content: [{ type: 'text', text: `Result for ${request.params.arguments?.id || 'unknown'}` }] }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; } diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 043bda15e..51fb4451d 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -482,7 +482,7 @@ const getServer = () => { // Simulate out-of-band work (async () => { await new Promise(resolve => setTimeout(resolve, duration)); - await taskStore.storeTaskResult(taskId, { + await taskStore.storeTaskResult(taskId, 'completed', { content: [ { type: 'text', diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index c994d54d9..627204142 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -134,6 +134,84 @@ describe('InMemoryTaskStore', () => { it('should throw if task not found', async () => { await expect(store.updateTaskStatus('non-existent', 'working')).rejects.toThrow('Task with ID non-existent not found'); }); + + describe('status lifecycle validation', () => { + it('should allow transition from working to input_required', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('input_required'); + }); + + it('should allow transition from working to completed', async () => { + await store.updateTaskStatus(taskId, 'completed'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('completed'); + }); + + it('should allow transition from working to failed', async () => { + await store.updateTaskStatus(taskId, 'failed'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('failed'); + }); + + it('should allow transition from working to cancelled', async () => { + await store.updateTaskStatus(taskId, 'cancelled'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('cancelled'); + }); + + it('should allow transition from input_required to working', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + await store.updateTaskStatus(taskId, 'working'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('working'); + }); + + it('should allow transition from input_required to completed', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + await store.updateTaskStatus(taskId, 'completed'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('completed'); + }); + + it('should allow transition from input_required to failed', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + await store.updateTaskStatus(taskId, 'failed'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('failed'); + }); + + it('should allow transition from input_required to cancelled', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + await store.updateTaskStatus(taskId, 'cancelled'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('cancelled'); + }); + + it('should reject transition from completed to any other status', async () => { + await store.updateTaskStatus(taskId, 'completed'); + await expect(store.updateTaskStatus(taskId, 'working')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'input_required')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'failed')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow('Cannot update task'); + }); + + it('should reject transition from failed to any other status', async () => { + await store.updateTaskStatus(taskId, 'failed'); + await expect(store.updateTaskStatus(taskId, 'working')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'input_required')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'completed')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow('Cannot update task'); + }); + + it('should reject transition from cancelled to any other status', async () => { + await store.updateTaskStatus(taskId, 'cancelled'); + await expect(store.updateTaskStatus(taskId, 'working')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'input_required')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'completed')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'failed')).rejects.toThrow('Cannot update task'); + }); + }); }); describe('storeTaskResult', () => { @@ -155,7 +233,7 @@ describe('InMemoryTaskStore', () => { content: [{ type: 'text' as const, text: 'Success!' }] }; - await store.storeTaskResult(taskId, result); + await store.storeTaskResult(taskId, 'completed', result); const task = await store.getTask(taskId); expect(task?.status).toBe('completed'); @@ -165,7 +243,79 @@ describe('InMemoryTaskStore', () => { }); it('should throw if task not found', async () => { - await expect(store.storeTaskResult('non-existent', {})).rejects.toThrow('Task with ID non-existent not found'); + await expect(store.storeTaskResult('non-existent', 'completed', {})).rejects.toThrow('Task with ID non-existent not found'); + }); + + it('should reject storing result for task already in completed status', async () => { + // First complete the task + const firstResult = { + content: [{ type: 'text' as const, text: 'First result' }] + }; + await store.storeTaskResult(taskId, 'completed', firstResult); + + // Try to store result again (should fail) + const secondResult = { + content: [{ type: 'text' as const, text: 'Second result' }] + }; + + await expect(store.storeTaskResult(taskId, 'completed', secondResult)).rejects.toThrow('Cannot store result for task'); + }); + + it('should store result with failed status', async () => { + const result = { + content: [{ type: 'text' as const, text: 'Error details' }], + isError: true + }; + + await store.storeTaskResult(taskId, 'failed', result); + + const task = await store.getTask(taskId); + expect(task?.status).toBe('failed'); + + const storedResult = await store.getTaskResult(taskId); + expect(storedResult).toEqual(result); + }); + + it('should reject storing result for task already in failed status', async () => { + // First fail the task + const firstResult = { + content: [{ type: 'text' as const, text: 'First error' }], + isError: true + }; + await store.storeTaskResult(taskId, 'failed', firstResult); + + // Try to store result again (should fail) + const secondResult = { + content: [{ type: 'text' as const, text: 'Second error' }], + isError: true + }; + + await expect(store.storeTaskResult(taskId, 'failed', secondResult)).rejects.toThrow('Cannot store result for task'); + }); + + it('should reject storing result for cancelled task', async () => { + // Mark task as cancelled + await store.updateTaskStatus(taskId, 'cancelled'); + + // Try to store result (should fail) + const result = { + content: [{ type: 'text' as const, text: 'Cancellation result' }] + }; + + await expect(store.storeTaskResult(taskId, 'completed', result)).rejects.toThrow('Cannot store result for task'); + }); + + it('should allow storing result from input_required status', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + + const result = { + content: [{ type: 'text' as const, text: 'Success!' }] + }; + + await store.storeTaskResult(taskId, 'completed', result); + + const task = await store.getTask(taskId); + expect(task?.status).toBe('completed'); }); }); @@ -194,7 +344,7 @@ describe('InMemoryTaskStore', () => { const result = { content: [{ type: 'text' as const, text: 'Result data' }] }; - await store.storeTaskResult(createdTask.taskId, result); + await store.storeTaskResult(createdTask.taskId, 'completed', result); const retrieved = await store.getTaskResult(createdTask.taskId); expect(retrieved).toEqual(result); @@ -244,7 +394,7 @@ describe('InMemoryTaskStore', () => { vi.advanceTimersByTime(500); // Store result (should reset timer) - await store.storeTaskResult(createdTask.taskId, { + await store.storeTaskResult(createdTask.taskId, 'completed', { content: [{ type: 'text' as const, text: 'Done' }] }); diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index b179e286a..0db79bb02 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -72,14 +72,21 @@ export class InMemoryTaskStore implements TaskStore { return stored ? { ...stored.task } : null; } - async storeTaskResult(taskId: string, result: Result, _sessionId?: string): Promise { + async storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result, _sessionId?: string): Promise { const stored = this.tasks.get(taskId); if (!stored) { throw new Error(`Task with ID ${taskId} not found`); } + // Don't allow storing results for tasks already in terminal state + if (isTerminal(stored.task.status)) { + throw new Error( + `Cannot store result for task ${taskId} in terminal status '${stored.task.status}'. Task results can only be stored once.` + ); + } + stored.result = result; - stored.task.status = 'completed'; + stored.task.status = status; // Reset cleanup timer to start from now (if ttl is set) if (stored.task.ttl) { @@ -116,6 +123,13 @@ export class InMemoryTaskStore implements TaskStore { throw new Error(`Task with ID ${taskId} not found`); } + // Don't allow transitions from terminal states + if (isTerminal(stored.task.status)) { + throw new Error( + `Cannot update task ${taskId} from terminal status '${stored.task.status}' to '${status}'. Terminal states (completed, failed, cancelled) cannot transition to other states.` + ); + } + stored.task.status = status; if (statusMessage) { stored.task.statusMessage = statusMessage; diff --git a/src/server/index.test.ts b/src/server/index.test.ts index d712797ff..049d97cac 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -999,7 +999,7 @@ describe('Task-based execution', () => { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; } @@ -1251,7 +1251,7 @@ describe('Task-based execution', () => { ] }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; } @@ -1357,7 +1357,7 @@ describe('Task-based execution', () => { content: { username: 'server-test-user', confirmed: true } }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; }); @@ -1448,7 +1448,7 @@ describe('Task-based execution', () => { content: { username: 'list-user' } }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; }); @@ -1536,7 +1536,7 @@ describe('Task-based execution', () => { content: { username: 'result-user', confirmed: true } }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; }); @@ -1626,7 +1626,7 @@ describe('Task-based execution', () => { content: { username: 'list-user' } }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; }); @@ -1744,7 +1744,7 @@ describe('Task-based execution', () => { content: [{ type: 'text', text: `Completed task ${request.params.arguments?.taskNum || 'unknown'}` }] }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; } @@ -1982,7 +1982,7 @@ test('should respect client task capabilities', async () => { content: { username: 'test-user' } }; if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, result); + await extra.taskStore.storeTaskResult(taskId, 'completed', result); } return result; }); diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 629874e29..2e5271d6c 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -4607,7 +4607,7 @@ describe('Tool-level task hints with automatic polling wrapper', () => { // Simulate async work setTimeout(async () => { - await store.storeTaskResult(task.taskId, { + await store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text' as const, text: `Processed: ${input}` }] }); }, 200); @@ -4715,7 +4715,7 @@ describe('Tool-level task hints with automatic polling wrapper', () => { // Simulate async work setTimeout(async () => { - await store.storeTaskResult(task.taskId, { + await store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text' as const, text: `Result: ${value * 2}` }] }); releaseLatch(); @@ -4826,7 +4826,7 @@ describe('Tool-level task hints with automatic polling wrapper', () => { // Simulate async work setTimeout(async () => { - await store.storeTaskResult(task.taskId, { + await store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text' as const, text: `Completed: ${data}` }] }); releaseLatch(); @@ -5117,8 +5117,7 @@ describe('Tool-level task hints with automatic polling wrapper', () => { // Simulate async failure setTimeout(async () => { - await store.updateTaskStatus(task.taskId, 'failed', 'Task failed'); - await store.storeTaskResult(task.taskId, { + await store.storeTaskResult(task.taskId, 'failed', { content: [{ type: 'text' as const, text: 'Error occurred' }], isError: true }); @@ -5228,9 +5227,6 @@ describe('Tool-level task hints with automatic polling wrapper', () => { // Simulate async cancellation setTimeout(async () => { await store.updateTaskStatus(task.taskId, 'cancelled', 'Task was cancelled'); - await store.storeTaskResult(task.taskId, { - content: [{ type: 'text' as const, text: 'Task cancelled' }] - }); releaseLatch(); }, 150); @@ -5263,9 +5259,9 @@ describe('Tool-level task hints with automatic polling wrapper', () => { CallToolResultSchema ); - // Should receive the cancellation result + // Should receive an error since cancelled tasks don't have results expect(result).toHaveProperty('content'); - expect(result.content).toEqual([{ type: 'text' as const, text: 'Task cancelled' }]); + expect(result.content).toEqual([{ type: 'text' as const, text: expect.stringContaining('has no result stored') }]); // Wait for async operations to complete await waitForLatch(); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 870c63fe6..72ecc9b2b 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -62,12 +62,12 @@ function createMockTaskStore(options?: { } return Promise.resolve(); }), - storeTaskResult: vi.fn((taskId: string, result: Result) => { + storeTaskResult: vi.fn((taskId: string, status: 'completed' | 'failed', result: Result) => { const task = tasks[taskId]; if (task) { - task.status = 'completed'; + task.status = status; task.result = result; - options?.onStatus?.('completed'); + options?.onStatus?.(status); } return Promise.resolve(); }), @@ -1737,7 +1737,7 @@ describe('Task-based execution', () => { content: [{ type: 'text', text: 'test result' }] }; - await mockTaskStore.storeTaskResult(task.taskId, testResult); + await mockTaskStore.storeTaskResult(task.taskId, 'completed', testResult); const serverProtocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index ce457dbc4..91ca41ad9 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -179,12 +179,13 @@ export interface RequestTaskStore { getTask(taskId: string): Promise; /** - * Stores the result of a completed task. + * Stores the result of a task and sets its final status. * * @param taskId - The task identifier + * @param status - The final status: 'completed' for success, 'failed' for errors * @param result - The result to store */ - storeTaskResult(taskId: string, result: Result): Promise; + storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; /** * Retrieves the stored result of a task. @@ -1153,8 +1154,8 @@ export abstract class Protocol { - await taskStore.storeTaskResult(taskId, result, sessionId); + storeTaskResult: async (taskId, status, result) => { + await taskStore.storeTaskResult(taskId, status, result, sessionId); // Get updated task state and send notification const task = await taskStore.getTask(taskId, sessionId); @@ -1173,17 +1174,18 @@ export abstract class Protocol { try { - // Check the current task status to avoid overwriting terminal states - // as a safeguard for when the TaskStore implementation doesn't try - // to avoid this. + // Check if task is in terminal state before attempting to update const task = await taskStore.getTask(taskId, sessionId); if (!task) { return; } + // Don't allow transitions from terminal states if (isTerminal(task.status)) { this._onerror( - new Error(`Failed to update status of task "${taskId}" from terminal status "${task.status}" to "${status}"`) + new Error( + `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` + ) ); return; } diff --git a/src/shared/task.test.ts b/src/shared/task.test.ts new file mode 100644 index 000000000..d33ca15cf --- /dev/null +++ b/src/shared/task.test.ts @@ -0,0 +1,26 @@ +import { describe, it, expect } from 'vitest'; +import { isTerminal } from './task.js'; + +describe('Task utility functions', () => { + describe('isTerminal', () => { + it('should return true for completed status', () => { + expect(isTerminal('completed')).toBe(true); + }); + + it('should return true for failed status', () => { + expect(isTerminal('failed')).toBe(true); + }); + + it('should return true for cancelled status', () => { + expect(isTerminal('cancelled')).toBe(true); + }); + + it('should return false for working status', () => { + expect(isTerminal('working')).toBe(false); + }); + + it('should return false for input_required status', () => { + expect(isTerminal('input_required')).toBe(false); + }); + }); +}); diff --git a/src/shared/task.ts b/src/shared/task.ts index fd32aa979..8c3d16cd8 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -29,13 +29,14 @@ export interface TaskStore { getTask(taskId: string, sessionId?: string): Promise; /** - * Stores the result of a completed task. + * Stores the result of a task and sets its final status. * * @param taskId - The task identifier + * @param status - The final status: 'completed' for success, 'failed' for errors * @param result - The result to store * @param sessionId - Optional session ID for binding the operation to a specific session */ - storeTaskResult(taskId: string, result: Result, sessionId?: string): Promise; + storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result, sessionId?: string): Promise; /** * Retrieves the stored result of a task. From 92dfcf883b2bcbc7eda6d8d182c29b7fb5734b6d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 15:07:36 -0800 Subject: [PATCH 45/84] Fix registerToolTask type inference with overloads --- src/server/mcp.ts | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/server/mcp.ts b/src/server/mcp.ts index aaaa5595c..d18ced7c8 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -982,7 +982,38 @@ export class McpServer { /** * Registers a task-based tool with a config object and callback. */ - registerToolTask( + registerToolTask>( + name: string, + config: { + title?: string; + description?: string; + outputSchema?: OutputArgs; + annotations?: NoTaskToolAnnotations; + _meta?: Record; + }, + handler: ToolTaskHandler + ): RegisteredTool; + + /** + * Registers a task-based tool with a config object and callback. + */ + registerToolTask, OutputArgs extends undefined | ZodRawShape | ZodType>( + name: string, + config: { + title?: string; + description?: string; + inputSchema: InputArgs; + outputSchema?: OutputArgs; + annotations?: NoTaskToolAnnotations; + _meta?: Record; + }, + handler: ToolTaskHandler + ): RegisteredTool; + + registerToolTask< + InputArgs extends undefined | ZodRawShape | ZodType, + OutputArgs extends undefined | ZodRawShape | ZodType + >( name: string, config: { title?: string; From 5af038dcaedafe5a3dd5a9c8fc062ba6386772d1 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 15:16:39 -0800 Subject: [PATCH 46/84] Clean up TTL handling --- src/examples/shared/inMemoryTaskStore.test.ts | 79 +++++++++++++++++++ src/examples/shared/inMemoryTaskStore.ts | 9 ++- src/shared/task.ts | 7 ++ 3 files changed, 92 insertions(+), 3 deletions(-) diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index 627204142..ea74a1460 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -463,6 +463,85 @@ describe('InMemoryTaskStore', () => { task = await store.getTask(createdTask2.taskId); expect(task).toBeNull(); }); + + it('should return actual TTL in task response', async () => { + // Test that the TaskStore returns the actual TTL it will use + // This implementation uses the requested TTL as-is, but implementations + // MAY override it (e.g., enforce maximum TTL limits) + const requestedTtl = 5000; + const taskParams: TaskCreationParams = { + ttl: requestedTtl + }; + const createdTask = await store.createTask(taskParams, 1111, { + method: 'tools/call', + params: {} + }); + + // The returned task should include the actual TTL that will be used + expect(createdTask.ttl).toBe(requestedTtl); + + // Verify the task is cleaned up after the actual TTL + vi.advanceTimersByTime(requestedTtl + 1); + const task = await store.getTask(createdTask.taskId); + expect(task).toBeNull(); + }); + + it('should support null TTL for unlimited lifetime', async () => { + // Test that null TTL means unlimited lifetime + const taskParams: TaskCreationParams = { + ttl: null + }; + const createdTask = await store.createTask(taskParams, 2222, { + method: 'tools/call', + params: {} + }); + + // The returned task should have null TTL + expect(createdTask.ttl).toBeNull(); + + // Task should not be cleaned up even after a long time + vi.advanceTimersByTime(100000); + const task = await store.getTask(createdTask.taskId); + expect(task).toBeDefined(); + expect(task?.taskId).toBe(createdTask.taskId); + }); + + it('should cleanup tasks regardless of status', async () => { + // Test that TTL cleanup happens regardless of task status + const taskParams: TaskCreationParams = { + ttl: 1000 + }; + + // Create tasks in different statuses + const workingTask = await store.createTask(taskParams, 3333, { + method: 'tools/call', + params: {} + }); + + const completedTask = await store.createTask(taskParams, 4444, { + method: 'tools/call', + params: {} + }); + await store.storeTaskResult(completedTask.taskId, 'completed', { + content: [{ type: 'text' as const, text: 'Done' }] + }); + + const failedTask = await store.createTask(taskParams, 5555, { + method: 'tools/call', + params: {} + }); + await store.storeTaskResult(failedTask.taskId, 'failed', { + content: [{ type: 'text' as const, text: 'Error' }] + }); + + // Fast-forward past TTL + vi.advanceTimersByTime(1001); + + // All tasks should be cleaned up regardless of status + expect(await store.getTask(workingTask.taskId)).toBeNull(); + expect(await store.getTask(completedTask.taskId)).toBeNull(); + expect(await store.getTask(failedTask.taskId)).toBeNull(); + }); }); describe('getAllTasks', () => { diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index 0db79bb02..3414d2bbd 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -39,11 +39,13 @@ export class InMemoryTaskStore implements TaskStore { throw new Error(`Task with ID ${taskId} already exists`); } + const actualTtl = taskParams.ttl ?? null; + // Create task with generated ID and timestamp const task: Task = { taskId, status: 'working', - ttl: taskParams.ttl ?? null, + ttl: actualTtl, createdAt: new Date().toISOString(), pollInterval: taskParams.pollInterval ?? 500 }; @@ -55,11 +57,12 @@ export class InMemoryTaskStore implements TaskStore { }); // Schedule cleanup if ttl is specified - if (taskParams.ttl) { + // Cleanup occurs regardless of task status + if (actualTtl) { const timer = setTimeout(() => { this.tasks.delete(taskId); this.cleanupTimers.delete(taskId); - }, taskParams.ttl); + }, actualTtl); this.cleanupTimers.set(taskId, timer); } diff --git a/src/shared/task.ts b/src/shared/task.ts index 8c3d16cd8..c7946006a 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -11,6 +11,13 @@ export interface TaskStore { * Creates a new task with the given creation parameters and original request. * The implementation must generate a unique taskId and createdAt timestamp. * + * TTL Management: + * - The implementation receives the TTL suggested by the requestor via taskParams.ttl + * - The implementation MAY override the requested TTL (e.g., to enforce limits) + * - The actual TTL used MUST be returned in the Task object + * - Null TTL indicates unlimited task lifetime (no automatic cleanup) + * - Cleanup SHOULD occur automatically after TTL expires, regardless of task status + * * @param taskParams - The task creation parameters from the request (ttl, pollInterval) * @param requestId - The JSON-RPC request ID * @param request - The original request that triggered task creation From 5996224eef9e638d7f51c016a21fe1f3d60f0dbd Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 15:21:18 -0800 Subject: [PATCH 47/84] Add task-listing tests --- src/shared/task-listing.test.ts | 167 ++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 src/shared/task-listing.test.ts diff --git a/src/shared/task-listing.test.ts b/src/shared/task-listing.test.ts new file mode 100644 index 000000000..9651df23c --- /dev/null +++ b/src/shared/task-listing.test.ts @@ -0,0 +1,167 @@ +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { InMemoryTransport } from '../inMemory.js'; +import { Client } from '../client/index.js'; +import { Server } from '../server/index.js'; +import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; + +describe('Task Listing with Pagination', () => { + let client: Client; + let server: Server; + let taskStore: InMemoryTaskStore; + let clientTransport: InMemoryTransport; + let serverTransport: InMemoryTransport; + + beforeEach(async () => { + taskStore = new InMemoryTaskStore(); + + [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + list: {}, + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + list: {}, + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + }); + + afterEach(async () => { + taskStore.cleanup(); + await client.close(); + await server.close(); + }); + + it('should return empty list when no tasks exist', async () => { + const result = await client.listTasks(); + + expect(result.tasks).toEqual([]); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should return all tasks when less than page size', async () => { + // Create 3 tasks + for (let i = 0; i < 3; i++) { + await taskStore.createTask({}, i, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + } + + const result = await client.listTasks(); + + expect(result.tasks).toHaveLength(3); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should paginate when more than page size exists', async () => { + // Create 15 tasks (page size is 10 in InMemoryTaskStore) + for (let i = 0; i < 15; i++) { + await taskStore.createTask({}, i, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + } + + // Get first page + const page1 = await client.listTasks(); + expect(page1.tasks).toHaveLength(10); + expect(page1.nextCursor).toBeDefined(); + + // Get second page using cursor + const page2 = await client.listTasks({ cursor: page1.nextCursor }); + expect(page2.tasks).toHaveLength(5); + expect(page2.nextCursor).toBeUndefined(); + }); + + it('should treat cursor as opaque token', async () => { + // Create 5 tasks + for (let i = 0; i < 5; i++) { + await taskStore.createTask({}, i, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + } + + // Get all tasks to get a valid cursor + const allTasks = taskStore.getAllTasks(); + const validCursor = allTasks[2].taskId; + + // Use the cursor - should work even though we don't know its internal structure + const result = await client.listTasks({ cursor: validCursor }); + expect(result.tasks).toHaveLength(2); + }); + + it('should return error for invalid cursor', async () => { + await taskStore.createTask({}, 1, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + + // Try to use an invalid cursor + await expect(client.listTasks({ cursor: 'invalid-cursor' })).rejects.toThrow(); + }); + + it('should ensure tasks accessible via tasks/get are also accessible via tasks/list', async () => { + // Create a task + const task = await taskStore.createTask({}, 1, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + + // Verify it's accessible via tasks/get + const getResult = await client.getTask({ taskId: task.taskId }); + expect(getResult.taskId).toBe(task.taskId); + + // Verify it's also accessible via tasks/list + const listResult = await client.listTasks(); + expect(listResult.tasks).toHaveLength(1); + expect(listResult.tasks[0].taskId).toBe(task.taskId); + }); + + it('should not include related-task metadata in list response', async () => { + // Create a task + await taskStore.createTask({}, 1, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + + const result = await client.listTasks(); + + // The response should have _meta but not include related-task metadata + expect(result._meta).toBeDefined(); + expect(result._meta?.['io.modelcontextprotocol/related-task']).toBeUndefined(); + }); +}); From b1e140119dce22ebc4b379c118be1aa137a29488 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 17:04:59 -0800 Subject: [PATCH 48/84] Implement cross-request progress --- src/shared/protocol.test.ts | 499 ++++++++++++++++++++++++++++++++++++ src/shared/protocol.ts | 52 +++- 2 files changed, 546 insertions(+), 5 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 72ecc9b2b..9a3795d29 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -2135,4 +2135,503 @@ describe('Request Cancellation vs Task Cancellation', () => { ); }); }); + + describe('progress notification support for tasks', () => { + it('should maintain progress token association after CreateTaskResult is returned', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + // Start a task-augmented request with progress callback + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + // Get the message ID from the sent request + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + expect(progressToken).toBe(messageId); + + // Simulate CreateTaskResult response + const taskId = 'test-task-123'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() + } + } + }); + } + + // Wait for response to be processed + await Promise.resolve(); + await Promise.resolve(); + + // Send a progress notification - should still work after CreateTaskResult + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 50, + total: 100 + } + }); + } + + // Wait for notification to be processed + await Promise.resolve(); + + // Verify progress callback was invoked + expect(progressCallback).toHaveBeenCalledWith({ + progress: 50, + total: 100 + }); + }); + + it('should stop progress notifications when task reaches terminal status (completed)', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + // Set up a request handler that will complete the task + protocol.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskStore) { + const task = await extra.taskStore.createTask({ ttl: 60000 }, extra.requestId, request); + + // Simulate async work then complete the task + setTimeout(async () => { + await extra.taskStore!.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: 'Done' }] + }); + }, 50); + + return { task }; + } + return { content: [] }; + }); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + // Start a task-augmented request with progress callback + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Create a task in the mock store first so it exists when we try to get it later + const createdTask = await taskStore.createTask({ ttl: 60000 }, messageId, request); + const taskId = createdTask.taskId; + + // Simulate CreateTaskResult response + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: createdTask + } + }); + } + + await Promise.resolve(); + await Promise.resolve(); + + // Progress notification should work while task is working + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 50, + total: 100 + } + }); + } + + await Promise.resolve(); + + expect(progressCallback).toHaveBeenCalledTimes(1); + + // Verify the task-progress association was created + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const taskProgressTokens = (protocol as any)._taskProgressTokens as Map; + expect(taskProgressTokens.has(taskId)).toBe(true); + expect(taskProgressTokens.get(taskId)).toBe(progressToken); + + // Simulate task completion by calling through the protocol's task store + // This will trigger the cleanup logic + const mockRequest = { jsonrpc: '2.0' as const, id: 999, method: 'test', params: {} }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const requestTaskStore = (protocol as any).requestTaskStore(mockRequest, undefined); + await requestTaskStore.storeTaskResult(taskId, 'completed', { content: [] }); + + // Wait for all async operations including notification sending to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the association was cleaned up + expect(taskProgressTokens.has(taskId)).toBe(false); + + // Try to send progress notification after task completion - should be ignored + progressCallback.mockClear(); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 100, + total: 100 + } + }); + } + + await Promise.resolve(); + + // Progress callback should NOT be invoked after task completion + expect(progressCallback).not.toHaveBeenCalled(); + }); + + it('should stop progress notifications when task reaches terminal status (failed)', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Simulate CreateTaskResult response + const taskId = 'test-task-456'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() + } + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Simulate task failure via storeTaskResult + await taskStore.storeTaskResult(taskId, 'failed', { + content: [], + isError: true + }); + + // Manually trigger the status notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/tasks/status', + params: { + task: { + taskId, + status: 'failed', + ttl: 60000, + createdAt: new Date().toISOString(), + statusMessage: 'Task failed' + } + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Try to send progress notification after task failure - should be ignored + progressCallback.mockClear(); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 75, + total: 100 + } + }); + } + + expect(progressCallback).not.toHaveBeenCalled(); + }); + + it('should stop progress notifications when task is cancelled', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Simulate CreateTaskResult response + const taskId = 'test-task-789'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() + } + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Simulate task cancellation via updateTaskStatus + await taskStore.updateTaskStatus(taskId, 'cancelled', 'User cancelled'); + + // Manually trigger the status notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/tasks/status', + params: { + task: { + taskId, + status: 'cancelled', + ttl: 60000, + createdAt: new Date().toISOString(), + statusMessage: 'User cancelled' + } + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Try to send progress notification after cancellation - should be ignored + progressCallback.mockClear(); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 25, + total: 100 + } + }); + } + + expect(progressCallback).not.toHaveBeenCalled(); + }); + + it('should use the same progressToken throughout task lifetime', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Simulate CreateTaskResult response + const taskId = 'test-task-consistency'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() + } + } + }); + } + + await Promise.resolve(); + await Promise.resolve(); + + // Send multiple progress notifications with the same token + const progressUpdates = [ + { progress: 25, total: 100 }, + { progress: 50, total: 100 }, + { progress: 75, total: 100 } + ]; + + for (const update of progressUpdates) { + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, // Same token for all notifications + ...update + } + }); + } + await Promise.resolve(); + } + + // Verify all progress notifications were received with the same token + expect(progressCallback).toHaveBeenCalledTimes(3); + expect(progressCallback).toHaveBeenNthCalledWith(1, { progress: 25, total: 100 }); + expect(progressCallback).toHaveBeenNthCalledWith(2, { progress: 50, total: 100 }); + expect(progressCallback).toHaveBeenNthCalledWith(3, { progress: 75, total: 100 }); + }); + }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 91ca41ad9..c5fa885c5 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -94,6 +94,8 @@ export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60000; export type RequestOptions = { /** * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. + * + * For task-augmented requests: progress notifications continue after CreateTaskResult is returned and stop automatically when the task reaches a terminal status. */ onprogress?: ProgressCallback; @@ -307,6 +309,9 @@ export abstract class Protocol = new Map(); private _pendingDebouncedNotifications = new Set(); + // Maps task IDs to progress tokens to keep handlers alive after CreateTaskResult + private _taskProgressTokens: Map = new Map(); + private _taskStore?: TaskStore; /** @@ -361,7 +366,7 @@ export abstract class Protocol = { @@ -634,7 +640,7 @@ export abstract class Protocol { - // Include related-task metadata if this request is part of a task (Requirement 6.1) + // Include related-task metadata if this request is part of a task const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; if (relatedTaskId) { notificationOptions.relatedTask = { taskId: relatedTaskId }; @@ -642,7 +648,7 @@ export abstract class Protocol { - // Include related-task metadata if this request is part of a task (Requirement 6.1) + // Include related-task metadata if this request is part of a task const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; if (relatedTaskId && !requestOptions.relatedTask) { requestOptions.relatedTask = { taskId: relatedTaskId }; @@ -741,9 +747,25 @@ export abstract class Protocol; + if (result.task && typeof result.task === 'object') { + const task = result.task as Record; + if (typeof task.taskId === 'string') { + isTaskResponse = true; + this._taskProgressTokens.set(task.taskId, messageId); + } + } + } + + if (!isTaskResponse) { + this._progressHandlers.delete(messageId); + } + if (isJSONRPCResponse(response)) { handler(response); } else { @@ -1124,6 +1146,18 @@ export abstract class Protocol { @@ -1202,6 +1240,10 @@ export abstract class Protocol Date: Wed, 19 Nov 2025 17:37:13 -0800 Subject: [PATCH 49/84] Backfill unit tests for task changes --- src/shared/protocol.test.ts | 994 +++++++++++++++++++++--------------- src/shared/task.test.ts | 67 +++ 2 files changed, 662 insertions(+), 399 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 9a3795d29..bf7d4dcfd 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -2135,503 +2135,699 @@ describe('Request Cancellation vs Task Cancellation', () => { ); }); }); +}); - describe('progress notification support for tasks', () => { - it('should maintain progress token association after CreateTaskResult is returned', async () => { - const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); +describe('Progress notification support for tasks', () => { + let protocol: Protocol; + let transport: MockTransport; + let sendSpy: MockInstance; - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); + beforeEach(() => { + transport = new MockTransport(); + sendSpy = vi.spyOn(transport, 'send'); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + }); - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; + it('should maintain progress token association after CreateTaskResult is returned', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) - }); + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); - // Start a task-augmented request with progress callback - protocol.beginRequest(request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }); + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; - // Get the message ID from the sent request - const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); - expect(progressToken).toBe(messageId); + // Start a task-augmented request with progress callback + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); - // Simulate CreateTaskResult response - const taskId = 'test-task-123'; - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } - } - }); - } + // Get the message ID from the sent request + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; - // Wait for response to be processed - await Promise.resolve(); - await Promise.resolve(); + expect(progressToken).toBe(messageId); - // Send a progress notification - should still work after CreateTaskResult - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 50, - total: 100 + // Simulate CreateTaskResult response + const taskId = 'test-task-123'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() } - }); - } + } + }); + } - // Wait for notification to be processed - await Promise.resolve(); + // Wait for response to be processed + await Promise.resolve(); + await Promise.resolve(); - // Verify progress callback was invoked - expect(progressCallback).toHaveBeenCalledWith({ - progress: 50, - total: 100 + // Send a progress notification - should still work after CreateTaskResult + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 50, + total: 100 + } }); + } + + // Wait for notification to be processed + await Promise.resolve(); + + // Verify progress callback was invoked + expect(progressCallback).toHaveBeenCalledWith({ + progress: 50, + total: 100 + }); + }); + + it('should stop progress notifications when task reaches terminal status (completed)', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + // Set up a request handler that will complete the task + protocol.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskStore) { + const task = await extra.taskStore.createTask({ ttl: 60000 }, extra.requestId, request); + + // Simulate async work then complete the task + setTimeout(async () => { + await extra.taskStore!.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: 'Done' }] + }); + }, 50); + + return { task }; + } + return { content: [] }; }); - it('should stop progress notifications when task reaches terminal status (completed)', async () => { - const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + // Start a task-augmented request with progress callback + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); - // Set up a request handler that will complete the task - protocol.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskStore) { - const task = await extra.taskStore.createTask({ ttl: 60000 }, extra.requestId, request); + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; - // Simulate async work then complete the task - setTimeout(async () => { - await extra.taskStore!.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: 'Done' }] - }); - }, 50); + // Create a task in the mock store first so it exists when we try to get it later + const createdTask = await taskStore.createTask({ ttl: 60000 }, messageId, request); + const taskId = createdTask.taskId; - return { task }; + // Simulate CreateTaskResult response + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: createdTask } - return { content: [] }; }); + } - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; + await Promise.resolve(); + await Promise.resolve(); - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) + // Progress notification should work while task is working + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 50, + total: 100 + } }); + } - // Start a task-augmented request with progress callback - protocol.beginRequest(request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }); + await Promise.resolve(); - const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; + expect(progressCallback).toHaveBeenCalledTimes(1); - // Create a task in the mock store first so it exists when we try to get it later - const createdTask = await taskStore.createTask({ ttl: 60000 }, messageId, request); - const taskId = createdTask.taskId; + // Verify the task-progress association was created + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const taskProgressTokens = (protocol as any)._taskProgressTokens as Map; + expect(taskProgressTokens.has(taskId)).toBe(true); + expect(taskProgressTokens.get(taskId)).toBe(progressToken); - // Simulate CreateTaskResult response - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: createdTask - } - }); - } + // Simulate task completion by calling through the protocol's task store + // This will trigger the cleanup logic + const mockRequest = { jsonrpc: '2.0' as const, id: 999, method: 'test', params: {} }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const requestTaskStore = (protocol as any).requestTaskStore(mockRequest, undefined); + await requestTaskStore.storeTaskResult(taskId, 'completed', { content: [] }); - await Promise.resolve(); - await Promise.resolve(); + // Wait for all async operations including notification sending to complete + await new Promise(resolve => setTimeout(resolve, 50)); - // Progress notification should work while task is working - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 50, - total: 100 - } - }); - } + // Verify the association was cleaned up + expect(taskProgressTokens.has(taskId)).toBe(false); - await Promise.resolve(); + // Try to send progress notification after task completion - should be ignored + progressCallback.mockClear(); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 100, + total: 100 + } + }); + } - expect(progressCallback).toHaveBeenCalledTimes(1); + await Promise.resolve(); - // Verify the task-progress association was created - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const taskProgressTokens = (protocol as any)._taskProgressTokens as Map; - expect(taskProgressTokens.has(taskId)).toBe(true); - expect(taskProgressTokens.get(taskId)).toBe(progressToken); + // Progress callback should NOT be invoked after task completion + expect(progressCallback).not.toHaveBeenCalled(); + }); - // Simulate task completion by calling through the protocol's task store - // This will trigger the cleanup logic - const mockRequest = { jsonrpc: '2.0' as const, id: 999, method: 'test', params: {} }; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const requestTaskStore = (protocol as any).requestTaskStore(mockRequest, undefined); - await requestTaskStore.storeTaskResult(taskId, 'completed', { content: [] }); + it('should stop progress notifications when task reaches terminal status (failed)', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); - // Wait for all async operations including notification sending to complete - await new Promise(resolve => setTimeout(resolve, 50)); + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); - // Verify the association was cleaned up - expect(taskProgressTokens.has(taskId)).toBe(false); + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; - // Try to send progress notification after task completion - should be ignored - progressCallback.mockClear(); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 100, - total: 100 + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Simulate CreateTaskResult response + const taskId = 'test-task-456'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() } - }); - } + } + }); + } - await Promise.resolve(); + await new Promise(resolve => setTimeout(resolve, 10)); - // Progress callback should NOT be invoked after task completion - expect(progressCallback).not.toHaveBeenCalled(); + // Simulate task failure via storeTaskResult + await taskStore.storeTaskResult(taskId, 'failed', { + content: [], + isError: true }); - it('should stop progress notifications when task reaches terminal status (failed)', async () => { - const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); - - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); + // Manually trigger the status notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/tasks/status', + params: { + task: { + taskId, + status: 'failed', + ttl: 60000, + createdAt: new Date().toISOString(), + statusMessage: 'Task failed' + } + } + }); + } - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; + await new Promise(resolve => setTimeout(resolve, 10)); - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) + // Try to send progress notification after task failure - should be ignored + progressCallback.mockClear(); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 75, + total: 100 + } }); + } - protocol.beginRequest(request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback + expect(progressCallback).not.toHaveBeenCalled(); + }); + + it('should stop progress notifications when task is cancelled', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Simulate CreateTaskResult response + const taskId = 'test-task-789'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() + } + } }); + } - const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; + await new Promise(resolve => setTimeout(resolve, 10)); - // Simulate CreateTaskResult response - const taskId = 'test-task-456'; - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } + // Simulate task cancellation via updateTaskStatus + await taskStore.updateTaskStatus(taskId, 'cancelled', 'User cancelled'); + + // Manually trigger the status notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/tasks/status', + params: { + task: { + taskId, + status: 'cancelled', + ttl: 60000, + createdAt: new Date().toISOString(), + statusMessage: 'User cancelled' } - }); - } + } + }); + } - await new Promise(resolve => setTimeout(resolve, 10)); + await new Promise(resolve => setTimeout(resolve, 10)); - // Simulate task failure via storeTaskResult - await taskStore.storeTaskResult(taskId, 'failed', { - content: [], - isError: true + // Try to send progress notification after cancellation - should be ignored + progressCallback.mockClear(); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 25, + total: 100 + } }); + } - // Manually trigger the status notification - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/tasks/status', - params: { - task: { - taskId, - status: 'failed', - ttl: 60000, - createdAt: new Date().toISOString(), - statusMessage: 'Task failed' - } + expect(progressCallback).not.toHaveBeenCalled(); + }); + + it('should use the same progressToken throughout task lifetime', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + protocol.beginRequest(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Simulate CreateTaskResult response + const taskId = 'test-task-consistency'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() } - }); - } + } + }); + } - await new Promise(resolve => setTimeout(resolve, 10)); + await Promise.resolve(); + await Promise.resolve(); + + // Send multiple progress notifications with the same token + const progressUpdates = [ + { progress: 25, total: 100 }, + { progress: 50, total: 100 }, + { progress: 75, total: 100 } + ]; - // Try to send progress notification after task failure - should be ignored - progressCallback.mockClear(); + for (const update of progressUpdates) { if (transport.onmessage) { transport.onmessage({ jsonrpc: '2.0', method: 'notifications/progress', params: { - progressToken, - progress: 75, - total: 100 + progressToken, // Same token for all notifications + ...update } }); } + await Promise.resolve(); + } + + // Verify all progress notifications were received with the same token + expect(progressCallback).toHaveBeenCalledTimes(3); + expect(progressCallback).toHaveBeenNthCalledWith(1, { progress: 25, total: 100 }); + expect(progressCallback).toHaveBeenNthCalledWith(2, { progress: 50, total: 100 }); + expect(progressCallback).toHaveBeenNthCalledWith(3, { progress: 75, total: 100 }); + }); + + it('should maintain progressToken throughout task lifetime', async () => { + await protocol.connect(transport); - expect(progressCallback).not.toHaveBeenCalled(); + const request = { + method: 'tools/call', + params: { name: 'long-running-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) }); - it('should stop progress notifications when task is cancelled', async () => { - const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const onProgressMock = vi.fn(); - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); + protocol.beginRequest(request, resultSchema, { + task: { + ttl: 60000 + }, + onprogress: onProgressMock + }); - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.params._meta.progressToken).toBeDefined(); + }); - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) - }); + it('should support progress notifications with task-augmented requests', async () => { + await protocol.connect(transport); - protocol.beginRequest(request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; - const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); - // Simulate CreateTaskResult response - const taskId = 'test-task-789'; - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } - } - }); - } + const onProgressMock = vi.fn(); - await new Promise(resolve => setTimeout(resolve, 10)); + protocol.beginRequest(request, resultSchema, { + task: { + ttl: 30000 + }, + onprogress: onProgressMock + }); - // Simulate task cancellation via updateTaskStatus - await taskStore.updateTaskStatus(taskId, 'cancelled', 'User cancelled'); + const sentMessage = sendSpy.mock.calls[0][0]; + const progressToken = sentMessage.params._meta.progressToken; - // Manually trigger the status notification - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/tasks/status', - params: { - task: { - taskId, - status: 'cancelled', - ttl: 60000, - createdAt: new Date().toISOString(), - statusMessage: 'User cancelled' - } - } - }); + // Simulate progress notification + transport.onmessage?.({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 50, + total: 100, + message: 'Processing...' } + }); - await new Promise(resolve => setTimeout(resolve, 10)); + await new Promise(resolve => setTimeout(resolve, 10)); - // Try to send progress notification after cancellation - should be ignored - progressCallback.mockClear(); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 25, - total: 100 - } - }); - } + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100, + message: 'Processing...' + }); + }); + + it('should continue progress notifications after CreateTaskResult', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; - expect(progressCallback).not.toHaveBeenCalled(); + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) }); - it('should use the same progressToken throughout task lifetime', async () => { - const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const onProgressMock = vi.fn(); - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); + protocol.beginRequest(request, resultSchema, { + task: { + ttl: 30000 + }, + onprogress: onProgressMock + }); - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; + const sentMessage = sendSpy.mock.calls[0][0]; + const progressToken = sentMessage.params._meta.progressToken; - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) + // Simulate CreateTaskResult response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sentMessage.id, + result: { + task: { + taskId: 'task-123', + status: 'working', + ttl: 30000, + createdAt: new Date().toISOString() + } + } }); + }, 5); - protocol.beginRequest(request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback + // Progress notifications should still work + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 75, + total: 100 + } }); + }, 10); - const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; + await new Promise(resolve => setTimeout(resolve, 20)); - // Simulate CreateTaskResult response - const taskId = 'test-task-consistency'; - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 75, + total: 100 + }); + }); +}); + +describe('Capability negotiation for tasks', () => { + it('should use empty objects for capability fields', () => { + const serverCapabilities = { + tasks: { + list: {}, + cancel: {}, + requests: { + tools: { + call: {} } - }); + } } + }; - await Promise.resolve(); - await Promise.resolve(); + expect(serverCapabilities.tasks.list).toEqual({}); + expect(serverCapabilities.tasks.cancel).toEqual({}); + expect(serverCapabilities.tasks.requests.tools.call).toEqual({}); + }); + + it('should include list and cancel in server capabilities', () => { + const serverCapabilities = { + tasks: { + list: {}, + cancel: {} + } + }; - // Send multiple progress notifications with the same token - const progressUpdates = [ - { progress: 25, total: 100 }, - { progress: 50, total: 100 }, - { progress: 75, total: 100 } - ]; + expect('list' in serverCapabilities.tasks).toBe(true); + expect('cancel' in serverCapabilities.tasks).toBe(true); + }); - for (const update of progressUpdates) { - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, // Same token for all notifications - ...update - } - }); - } - await Promise.resolve(); + it('should include list and cancel in client capabilities', () => { + const clientCapabilities = { + tasks: { + list: {}, + cancel: {} } + }; - // Verify all progress notifications were received with the same token - expect(progressCallback).toHaveBeenCalledTimes(3); - expect(progressCallback).toHaveBeenNthCalledWith(1, { progress: 25, total: 100 }); - expect(progressCallback).toHaveBeenNthCalledWith(2, { progress: 50, total: 100 }); - expect(progressCallback).toHaveBeenNthCalledWith(3, { progress: 75, total: 100 }); - }); + expect('list' in clientCapabilities.tasks).toBe(true); + expect('cancel' in clientCapabilities.tasks).toBe(true); }); }); diff --git a/src/shared/task.test.ts b/src/shared/task.test.ts index d33ca15cf..4d3843740 100644 --- a/src/shared/task.test.ts +++ b/src/shared/task.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect } from 'vitest'; import { isTerminal } from './task.js'; +import type { Task } from '../types.js'; describe('Task utility functions', () => { describe('isTerminal', () => { @@ -24,3 +25,69 @@ describe('Task utility functions', () => { }); }); }); + +describe('Task Schema Validation', () => { + it('should validate task with ttl field', () => { + const task: Task = { + taskId: 'test-123', + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString(), + pollInterval: 1000 + }; + + expect(task.ttl).toBe(60000); + expect(task.createdAt).toBeDefined(); + expect(typeof task.createdAt).toBe('string'); + }); + + it('should validate task with null ttl', () => { + const task: Task = { + taskId: 'test-456', + status: 'completed', + ttl: null, + createdAt: new Date().toISOString() + }; + + expect(task.ttl).toBeNull(); + }); + + it('should validate task with statusMessage field', () => { + const task: Task = { + taskId: 'test-789', + status: 'failed', + ttl: null, + createdAt: new Date().toISOString(), + statusMessage: 'Operation failed due to timeout' + }; + + expect(task.statusMessage).toBe('Operation failed due to timeout'); + }); + + it('should validate task with createdAt in ISO 8601 format', () => { + const now = new Date(); + const task: Task = { + taskId: 'test-iso', + status: 'working', + ttl: 30000, + createdAt: now.toISOString() + }; + + expect(task.createdAt).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/); + expect(new Date(task.createdAt).getTime()).toBe(now.getTime()); + }); + + it('should validate all task statuses', () => { + const statuses: Task['status'][] = ['working', 'input_required', 'completed', 'failed', 'cancelled']; + + statuses.forEach(status => { + const task: Task = { + taskId: `test-${status}`, + status, + ttl: null, + createdAt: new Date().toISOString() + }; + expect(task.status).toBe(status); + }); + }); +}); From 370fd07ff965a1030d7a95c3aaeaa42c40075dd5 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 18:22:34 -0800 Subject: [PATCH 50/84] Implement end to end tests for tasks --- src/integration-tests/taskLifecycle.test.ts | 704 ++++++++++++++++++++ 1 file changed, 704 insertions(+) create mode 100644 src/integration-tests/taskLifecycle.test.ts diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts new file mode 100644 index 000000000..93e6ef8c9 --- /dev/null +++ b/src/integration-tests/taskLifecycle.test.ts @@ -0,0 +1,704 @@ +import { createServer, type Server } from 'node:http'; +import { AddressInfo } from 'node:net'; +import { randomUUID } from 'node:crypto'; +import { Client } from '../client/index.js'; +import { StreamableHTTPClientTransport } from '../client/streamableHttp.js'; +import { McpServer } from '../server/mcp.js'; +import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; +import { TaskSchema } from '../types.js'; +import { z } from 'zod'; +import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; + +describe('Task Lifecycle Integration Tests', () => { + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + let taskStore: InMemoryTaskStore; + + beforeEach(async () => { + // Create task store + taskStore = new InMemoryTaskStore(); + + // Create MCP server with task support + mcpServer = new McpServer( + { name: 'test-server', version: '1.0.0' }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + }, + list: {}, + cancel: {} + } + }, + taskStore + } + ); + + // Register a long-running tool using registerToolTask + mcpServer.registerToolTask( + 'long-task', + { + title: 'Long Running Task', + description: 'A tool that takes time to complete', + inputSchema: { + duration: z.number().describe('Duration in milliseconds').default(1000), + shouldFail: z.boolean().describe('Whether the task should fail').default(false) + } + }, + { + async createTask({ duration, shouldFail }, extra) { + const task = await extra.taskStore.createTask( + { + ttl: 60000, + pollInterval: 100 + }, + 0, + { method: 'tools/call', params: { name: 'long-task', arguments: { duration, shouldFail } } } + ); + + // Simulate async work + (async () => { + await new Promise(resolve => setTimeout(resolve, duration)); + + if (shouldFail) { + await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + content: [{ type: 'text', text: 'Task failed as requested' }], + isError: true + }); + } else { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Completed after ${duration}ms` }] + }); + } + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + // Register a tool that requires input + mcpServer.registerToolTask( + 'input-task', + { + title: 'Input Required Task', + description: 'A tool that requires user input', + inputSchema: { + initialMessage: z.string().describe('Initial message').default('Waiting for input') + } + }, + { + async createTask({ initialMessage }, extra) { + const task = await extra.taskStore.createTask( + { + ttl: 60000, + pollInterval: 100 + }, + 0, + { method: 'tools/call', params: { name: 'input-task', arguments: { initialMessage } } } + ); + + // Simulate moving to input_required status + (async () => { + await new Promise(resolve => setTimeout(resolve, 100)); + await extra.taskStore.updateTaskStatus(task.taskId, 'input_required'); + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + // Create transport + serverTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID() + }); + + await mcpServer.connect(serverTransport); + + // Create HTTP server + server = createServer(async (req, res) => { + await serverTransport.handleRequest(req, res); + }); + + // Start server + baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + }); + + afterEach(async () => { + taskStore.cleanup(); + await mcpServer.close().catch(() => {}); + await serverTransport.close().catch(() => {}); + server.close(); + }); + + describe('Task Creation and Completion', () => { + it('should create a task and return CreateTaskResult', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 500, + shouldFail: false + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + + // Verify CreateTaskResult structure + expect(createResult).toHaveProperty('task'); + expect(createResult.task).toHaveProperty('taskId'); + expect(createResult.task.status).toBe('working'); + expect(createResult.task.ttl).toBe(60000); + expect(createResult.task.createdAt).toBeDefined(); + expect(createResult.task.pollInterval).toBe(100); + + // Verify task is stored in taskStore + const taskId = createResult.task.taskId; + const storedTask = await taskStore.getTask(taskId); + expect(storedTask).toBeDefined(); + expect(storedTask?.taskId).toBe(taskId); + expect(storedTask?.status).toBe('working'); + + // Wait for completion + await new Promise(resolve => setTimeout(resolve, 600)); + + // Verify task completed + const completedTask = await taskStore.getTask(taskId); + expect(completedTask?.status).toBe('completed'); + + // Verify result is stored + const result = await taskStore.getTaskResult(taskId); + expect(result).toBeDefined(); + expect(result.content).toEqual([{ type: 'text', text: 'Completed after 500ms' }]); + + await transport.close(); + }); + + it('should handle task failure correctly', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will fail + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 300, + shouldFail: true + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + + const taskId = createResult.task.taskId; + + // Wait for failure + await new Promise(resolve => setTimeout(resolve, 400)); + + // Verify task failed + const task = await taskStore.getTask(taskId); + expect(task?.status).toBe('failed'); + + // Verify error result is stored + const result = await taskStore.getTaskResult(taskId); + expect(result.content).toEqual([{ type: 'text', text: 'Task failed as requested' }]); + expect(result.isError).toBe(true); + + await transport.close(); + }); + }); + + describe('Task Cancellation', () => { + it('should cancel a working task', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a long-running task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 5000 + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + + const taskId = createResult.task.taskId; + + // Verify task is working + let task = await taskStore.getTask(taskId); + expect(task?.status).toBe('working'); + + // Cancel the task + await taskStore.updateTaskStatus(taskId, 'cancelled'); + + // Verify task is cancelled + task = await taskStore.getTask(taskId); + expect(task?.status).toBe('cancelled'); + + await transport.close(); + }); + + it('should reject cancellation of completed task', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a quick task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 100 + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + + const taskId = createResult.task.taskId; + + // Wait for completion + await new Promise(resolve => setTimeout(resolve, 200)); + + // Verify task is completed + const task = await taskStore.getTask(taskId); + expect(task?.status).toBe('completed'); + + // Try to cancel (should fail) + await expect(taskStore.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow(); + + await transport.close(); + }); + }); + + describe('Input Required Flow', () => { + it('should handle input_required status', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that requires input + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'input-task', + arguments: { + initialMessage: 'Need user input' + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + + const taskId = createResult.task.taskId; + + // Wait for input_required status + await new Promise(resolve => setTimeout(resolve, 200)); + + const task = await taskStore.getTask(taskId); + expect(task?.status).toBe('input_required'); + + // Simulate providing input and completing the task + await taskStore.updateTaskStatus(taskId, 'working'); + await taskStore.storeTaskResult(taskId, 'completed', { + content: [{ type: 'text', text: 'Input received and processed' }] + }); + + // Verify completion + const completedTask = await taskStore.getTask(taskId); + expect(completedTask?.status).toBe('completed'); + + await transport.close(); + }); + }); + + describe('Task Listing and Pagination', () => { + it('should list tasks', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create multiple tasks + const taskIds: string[] = []; + for (let i = 0; i < 3; i++) { + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 1000 + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + taskIds.push(createResult.task.taskId); + } + + // List tasks using taskStore + const listResult = await taskStore.listTasks(); + + expect(listResult.tasks.length).toBeGreaterThanOrEqual(3); + expect(listResult.tasks.some(t => taskIds.includes(t.taskId))).toBe(true); + + await transport.close(); + }); + + it('should handle pagination with large datasets', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create 15 tasks (more than page size of 10) + for (let i = 0; i < 15; i++) { + await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 5000 + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + } + + // Get first page using taskStore + const page1 = await taskStore.listTasks(); + + expect(page1.tasks.length).toBe(10); + expect(page1.nextCursor).toBeDefined(); + + // Get second page + const page2 = await taskStore.listTasks(page1.nextCursor); + + expect(page2.tasks.length).toBeGreaterThanOrEqual(5); + + await transport.close(); + }); + }); + + describe('Error Handling', () => { + it('should return null for non-existent task', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Try to get non-existent task + const task = await taskStore.getTask('non-existent'); + expect(task).toBeNull(); + + await transport.close(); + }); + + it('should return error for invalid task operation', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create and complete a task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 100 + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + + const taskId = createResult.task.taskId; + + // Wait for completion + await new Promise(resolve => setTimeout(resolve, 200)); + + // Try to cancel completed task (should fail) + await expect(taskStore.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow(); + + await transport.close(); + }); + }); + + describe('TTL and Cleanup', () => { + it('should respect TTL in task creation', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task with specific TTL + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 100 + }, + task: { + ttl: 5000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + + const taskId = createResult.task.taskId; + + // Verify TTL is set correctly + expect(createResult.task.ttl).toBe(60000); // The task store uses 60000 as default + + // Task should exist + const task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task).toBeDefined(); + expect(task.ttl).toBe(60000); + + await transport.close(); + }); + }); + + describe('Concurrent Operations', () => { + it('should handle multiple concurrent task creations', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create multiple tasks concurrently + const promises = Array.from({ length: 5 }, () => + client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 500 + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ) + ); + + const results = await Promise.all(promises); + + // Verify all tasks were created with unique IDs + const taskIds = results.map(r => r.task.taskId); + expect(new Set(taskIds).size).toBe(5); + + // Verify all tasks are in working status + for (const result of results) { + expect(result.task.status).toBe('working'); + } + + await transport.close(); + }); + + it('should handle concurrent operations on same task', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 2000 + }, + task: { + ttl: 60000 + } + } + }, + z.object({ + task: TaskSchema + }) + ); + + const taskId = createResult.task.taskId; + + // Perform multiple concurrent gets + const getPromises = Array.from({ length: 5 }, () => + client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ) + ); + + const tasks = await Promise.all(getPromises); + + // All should return the same task + for (const task of tasks) { + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('working'); + } + + await transport.close(); + }); + }); +}); From 304ff0ae27a7f9c624de5f3bab837831c9dd5c17 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 19:31:35 -0800 Subject: [PATCH 51/84] Add elicitation integration tests, fix type issues --- src/integration-tests/taskLifecycle.test.ts | 230 +++++++++++++++----- src/server/index.test.ts | 6 +- src/server/mcp.ts | 12 +- src/shared/protocol.ts | 42 ++-- 4 files changed, 199 insertions(+), 91 deletions(-) diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts index 93e6ef8c9..e15586f92 100644 --- a/src/integration-tests/taskLifecycle.test.ts +++ b/src/integration-tests/taskLifecycle.test.ts @@ -5,9 +5,10 @@ import { Client } from '../client/index.js'; import { StreamableHTTPClientTransport } from '../client/streamableHttp.js'; import { McpServer } from '../server/mcp.js'; import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; -import { TaskSchema } from '../types.js'; +import { CallToolResultSchema, CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, TaskSchema } from '../types.js'; import { z } from 'zod'; import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTransport } from '../inMemory.js'; describe('Task Lifecycle Integration Tests', () => { let server: Server; @@ -93,31 +94,64 @@ describe('Task Lifecycle Integration Tests', () => { } ); - // Register a tool that requires input + // Register a tool that requires input via elicitation mcpServer.registerToolTask( 'input-task', { title: 'Input Required Task', description: 'A tool that requires user input', inputSchema: { - initialMessage: z.string().describe('Initial message').default('Waiting for input') + userName: z.string().describe('User name').optional() } }, { - async createTask({ initialMessage }, extra) { + async createTask({ userName }, extra) { const task = await extra.taskStore.createTask( { ttl: 60000, pollInterval: 100 }, 0, - { method: 'tools/call', params: { name: 'input-task', arguments: { initialMessage } } } + { method: 'tools/call', params: { name: 'input-task', arguments: { userName } } } ); - // Simulate moving to input_required status + // Perform async work that requires elicitation (async () => { await new Promise(resolve => setTimeout(resolve, 100)); - await extra.taskStore.updateTaskStatus(task.taskId, 'input_required'); + + // If userName not provided, request it via elicitation + if (!userName) { + const elicitationResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + message: 'What is your name?', + requestedSchema: { + type: 'object', + properties: { + userName: { type: 'string' } + }, + required: ['userName'] + } + } + }, + ElicitResultSchema + ); + + // Complete with the elicited name + const name = + elicitationResult.action === 'accept' && elicitationResult.content + ? elicitationResult.content.userName + : 'Unknown'; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Hello, ${name}!` }] + }); + } else { + // Complete immediately if userName was provided + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Hello, ${userName}!` }] + }); + } })(); return { task }; @@ -189,9 +223,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); // Verify CreateTaskResult structure @@ -248,9 +280,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); const taskId = createResult.task.taskId; @@ -295,9 +325,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); const taskId = createResult.task.taskId; @@ -339,9 +367,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); const taskId = createResult.task.taskId; @@ -361,7 +387,107 @@ describe('Task Lifecycle Integration Tests', () => { }); describe('Input Required Flow', () => { - it('should handle input_required status', async () => { + it('should handle elicitation during tool execution', async () => { + // Use InMemoryTransport for this test since elicitation requires bidirectional communication + const elicitClient = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + // Set up elicitation handler on client + elicitClient.setRequestHandler(ElicitRequestSchema, async request => { + // Verify elicitation request structure + expect(request.params.message).toBe('What is your name?'); + expect(request.params.requestedSchema).toHaveProperty('properties'); + + // Respond with user input + return { + action: 'accept' as const, + content: { + userName: 'Alice' + } + }; + }); + + const [elicitClientTransport, elicitServerTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([elicitClient.connect(elicitClientTransport), mcpServer.connect(elicitServerTransport)]); + + // Create a task without userName (will trigger elicitation) + const createResult = await elicitClient.request( + { + method: 'tools/call', + params: { + name: 'input-task', + arguments: {}, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for elicitation to occur and task to transition to input_required + await new Promise(resolve => setTimeout(resolve, 200)); + + // Check task status - should be input_required during elicitation + let task = await elicitClient.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + + // Task should either be input_required or already completed (if elicitation was fast) + expect(['input_required', 'working', 'completed']).toContain(task.status); + + // Wait for completion after elicitation response + // Poll until task completes or times out + let attempts = 0; + while (attempts < 20) { + task = await elicitClient.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + if (task.status === 'completed' || task.status === 'failed') { + break; + } + await new Promise(resolve => setTimeout(resolve, 100)); + attempts++; + } + + // Verify task completed with elicited input + expect(task.status).toBe('completed'); + + // Get result + const result = await elicitClient.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + expect(result.content).toEqual([{ type: 'text', text: 'Hello, Alice!' }]); + + await elicitClientTransport.close(); + await elicitServerTransport.close(); + }); + + it('should complete immediately when input is provided upfront', async () => { const client = new Client({ name: 'test-client', version: '1.0.0' @@ -370,42 +496,48 @@ describe('Task Lifecycle Integration Tests', () => { const transport = new StreamableHTTPClientTransport(baseUrl); await client.connect(transport); - // Create a task that requires input + // Create a task with userName provided (no elicitation needed) const createResult = await client.request( { method: 'tools/call', params: { name: 'input-task', arguments: { - initialMessage: 'Need user input' + userName: 'Bob' }, task: { ttl: 60000 } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); const taskId = createResult.task.taskId; - // Wait for input_required status - await new Promise(resolve => setTimeout(resolve, 200)); + // Wait for completion + await new Promise(resolve => setTimeout(resolve, 300)); - const task = await taskStore.getTask(taskId); - expect(task?.status).toBe('input_required'); + // Verify task completed without elicitation + const task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); - // Simulate providing input and completing the task - await taskStore.updateTaskStatus(taskId, 'working'); - await taskStore.storeTaskResult(taskId, 'completed', { - content: [{ type: 'text', text: 'Input received and processed' }] - }); + // Get result + const result = await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); - // Verify completion - const completedTask = await taskStore.getTask(taskId); - expect(completedTask?.status).toBe('completed'); + expect(result.content).toEqual([{ type: 'text', text: 'Hello, Bob!' }]); await transport.close(); }); @@ -437,9 +569,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); taskIds.push(createResult.task.taskId); } @@ -477,9 +607,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); } @@ -538,9 +666,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); const taskId = createResult.task.taskId; @@ -579,9 +705,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); const taskId = createResult.task.taskId; @@ -629,9 +753,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ) ); @@ -672,9 +794,7 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - z.object({ - task: TaskSchema - }) + CreateTaskResultSchema ); const taskId = createResult.task.taskId; diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 049d97cac..cc2f81c70 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -6,6 +6,7 @@ import type { Transport } from '../shared/transport.js'; import { CreateMessageRequestSchema, ElicitRequestSchema, + ElicitResultSchema, ErrorCode, LATEST_PROTOCOL_VERSION, ListPromptsRequestSchema, @@ -1236,10 +1237,7 @@ describe('Task-based execution', () => { } } }, - z.object({ - action: z.enum(['accept', 'decline', 'cancel']), - content: z.record(z.unknown()).optional() - }) + ElicitResultSchema ); const result = { diff --git a/src/server/mcp.ts b/src/server/mcp.ts index d18ced7c8..5ed3fbfa8 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -1228,7 +1228,7 @@ export class ResourceTemplate { export type BaseToolCallback< SendResultT extends Result, - Extra extends RequestHandlerExtra, + Extra extends RequestHandlerExtra, Args extends undefined | ZodRawShape | ZodType > = Args extends ZodRawShape ? (args: z.objectOutputType, extra: Extra) => SendResultT | Promise @@ -1252,13 +1252,11 @@ export type ToolCallback Args >; -export interface CreateTaskRequestHandlerExtra - extends RequestHandlerExtra { +export interface CreateTaskRequestHandlerExtra extends RequestHandlerExtra { taskStore: RequestTaskStore; } -export interface TaskRequestHandlerExtra - extends RequestHandlerExtra { +export interface TaskRequestHandlerExtra extends RequestHandlerExtra { taskId: string; taskStore: RequestTaskStore; } @@ -1266,12 +1264,12 @@ export interface TaskRequestHandlerExtra export type CreateTaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShape | ZodType = undefined -> = BaseToolCallback, Args>; +> = BaseToolCallback; export type TaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShape | ZodType = undefined -> = BaseToolCallback, Args>; +> = BaseToolCallback; export interface ToolTaskHandler = undefined> { createTask: CreateTaskRequestHandler; diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index c5fa885c5..3b0d3b71e 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -218,11 +218,7 @@ export interface RequestTaskStore { /** * Extra data given to request handlers. */ -export type RequestHandlerExtra< - SendRequestT extends Request, - SendNotificationT extends Notification, - SendResultT extends Result = Result -> = { +export type RequestHandlerExtra = { /** * An abort signal used to communicate if the request was cancelled from the sender's side. */ @@ -272,11 +268,7 @@ export type RequestHandlerExtra< * * This is used by certain transports to correctly associate related messages. */ - sendRequest: >( - request: SendRequestT, - resultSchema: U, - options?: TaskRequestOptions - ) => Promise>; + sendRequest: >(request: SendRequestT, resultSchema: U, options?: TaskRequestOptions) => Promise>; }; /** @@ -300,7 +292,7 @@ export abstract class Protocol) => Promise + (request: JSONRPCRequest, extra: RequestHandlerExtra) => Promise > = new Map(); private _requestHandlerAbortControllers: Map = new Map(); private _notificationHandlers: Map Promise> = new Map(); @@ -331,10 +323,7 @@ export abstract class Protocol - ) => Promise; + fallbackRequestHandler?: (request: JSONRPCRequest, extra: RequestHandlerExtra) => Promise; /** * A handler to invoke for any notification types that do not have their own handler installed. @@ -635,7 +624,7 @@ export abstract class Protocol = { + const fullExtra: RequestHandlerExtra = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, _meta: request.params?._meta, @@ -825,11 +814,11 @@ export abstract class Protocol>( + beginRequest>( request: SendRequestT, resultSchema: T, options?: RequestOptions - ): PendingRequest { + ): PendingRequest> { const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; // Send the request @@ -950,7 +939,13 @@ export abstract class Protocol>, + result, + resultSchema, + undefined, + this._options?.defaultTaskPollInterval + ); } /** @@ -958,7 +953,7 @@ export abstract class Protocol>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { + request>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { return this.beginRequest(request, resultSchema, options).result(); } @@ -973,7 +968,7 @@ export abstract class Protocol>( + async getTaskResult>( params: GetTaskPayloadRequest['params'], resultSchema: T, options?: RequestOptions @@ -1095,10 +1090,7 @@ export abstract class Protocol >( requestSchema: T, - handler: ( - request: z.infer, - extra: RequestHandlerExtra - ) => SendResultT | Promise + handler: (request: z.infer, extra: RequestHandlerExtra) => SendResultT | Promise ): void { const method = requestSchema.shape.method.value; this.assertRequestHandlerCapability(method); From b562a09870dc9b762120a33c628a4932afa04d8a Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 19 Nov 2025 23:47:29 -0800 Subject: [PATCH 52/84] Implement SSE side-channeling on tasks/result --- src/integration-tests/taskLifecycle.test.ts | 899 +++++++++++- src/shared/protocol.test.ts | 1453 ++++++++++++++++++- src/shared/protocol.ts | 431 +++++- 3 files changed, 2662 insertions(+), 121 deletions(-) diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts index e15586f92..220fcbb22 100644 --- a/src/integration-tests/taskLifecycle.test.ts +++ b/src/integration-tests/taskLifecycle.test.ts @@ -8,7 +8,7 @@ import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; import { CallToolResultSchema, CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, TaskSchema } from '../types.js'; import { z } from 'zod'; import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; -import { InMemoryTransport } from '../inMemory.js'; +import type { TaskRequestOptions } from '../shared/protocol.js'; describe('Task Lifecycle Integration Tests', () => { let server: Server; @@ -66,15 +66,19 @@ describe('Task Lifecycle Integration Tests', () => { (async () => { await new Promise(resolve => setTimeout(resolve, duration)); - if (shouldFail) { - await extra.taskStore.storeTaskResult(task.taskId, 'failed', { - content: [{ type: 'text', text: 'Task failed as requested' }], - isError: true - }); - } else { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: `Completed after ${duration}ms` }] - }); + try { + if (shouldFail) { + await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + content: [{ type: 'text', text: 'Task failed as requested' }], + isError: true + }); + } else { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Completed after ${duration}ms` }] + }); + } + } catch { + // Task may have been cleaned up if test ended } })(); @@ -135,7 +139,8 @@ describe('Task Lifecycle Integration Tests', () => { } } }, - ElicitResultSchema + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions ); // Complete with the elicited name @@ -143,14 +148,22 @@ describe('Task Lifecycle Integration Tests', () => { elicitationResult.action === 'accept' && elicitationResult.content ? elicitationResult.content.userName : 'Unknown'; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: `Hello, ${name}!` }] - }); + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Hello, ${name}!` }] + }); + } catch { + // Task may have been cleaned up if test ended + } } else { // Complete immediately if userName was provided - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: `Hello, ${userName}!` }] - }); + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Hello, ${userName}!` }] + }); + } catch { + // Task may have been cleaned up if test ended + } } })(); @@ -386,9 +399,181 @@ describe('Task Lifecycle Integration Tests', () => { }); }); + describe('Multiple Queued Messages', () => { + it('should deliver multiple queued messages in order', async () => { + // Register a tool that sends multiple server requests during execution + mcpServer.registerToolTask( + 'multi-request-task', + { + title: 'Multi Request Task', + description: 'A tool that sends multiple server requests', + inputSchema: { + requestCount: z.number().describe('Number of requests to send').default(3) + } + }, + { + async createTask({ requestCount }, extra) { + const task = await extra.taskStore.createTask( + { + ttl: 60000, + pollInterval: 100 + }, + 0, + { method: 'tools/call', params: { name: 'multi-request-task', arguments: { requestCount } } } + ); + + // Perform async work that sends multiple requests + (async () => { + await new Promise(resolve => setTimeout(resolve, 100)); + + const responses: string[] = []; + + // Send multiple elicitation requests + for (let i = 0; i < requestCount; i++) { + const elicitationResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + message: `Request ${i + 1} of ${requestCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] + } + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ); + + if (elicitationResult.action === 'accept' && elicitationResult.content) { + responses.push(elicitationResult.content.response as string); + } + } + + // Complete with all responses + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Received responses: ${responses.join(', ')}` }] + }); + } catch { + // Task may have been cleaned up if test ended + } + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + const receivedMessages: Array<{ method: string; message: string }> = []; + + // Set up elicitation handler on client to track message order + client.setRequestHandler(ElicitRequestSchema, async request => { + // Track the message + receivedMessages.push({ + method: request.method, + message: request.params.message + }); + + // Extract the request number from the message + const match = request.params.message.match(/Request (\d+) of (\d+)/); + const requestNum = match ? match[1] : 'unknown'; + + // Respond with the request number + return { + action: 'accept' as const, + content: { + response: `Response ${requestNum}` + } + }; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will send 3 requests + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'multi-request-task', + arguments: { + requestCount: 3 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for messages to be queued + await new Promise(resolve => setTimeout(resolve, 200)); + + // Call tasks/result to receive all queued messages + // This should deliver all 3 elicitation requests in order + const result = await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + // Verify all messages were delivered in order + expect(receivedMessages.length).toBe(3); + expect(receivedMessages[0].message).toBe('Request 1 of 3'); + expect(receivedMessages[1].message).toBe('Request 2 of 3'); + expect(receivedMessages[2].message).toBe('Request 3 of 3'); + + // Verify final result includes all responses + expect(result.content).toEqual([{ type: 'text', text: 'Received responses: Response 1, Response 2, Response 3' }]); + + // Verify task is completed + const task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); + + await transport.close(); + }, 10000); + }); + describe('Input Required Flow', () => { it('should handle elicitation during tool execution', async () => { - // Use InMemoryTransport for this test since elicitation requires bidirectional communication const elicitClient = new Client( { name: 'test-client', @@ -408,16 +593,17 @@ describe('Task Lifecycle Integration Tests', () => { expect(request.params.requestedSchema).toHaveProperty('properties'); // Respond with user input - return { + const response = { action: 'accept' as const, content: { userName: 'Alice' } }; + return response; }); - const [elicitClientTransport, elicitServerTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([elicitClient.connect(elicitClientTransport), mcpServer.connect(elicitServerTransport)]); + const transport = new StreamableHTTPClientTransport(baseUrl); + await elicitClient.connect(transport); // Create a task without userName (will trigger elicitation) const createResult = await elicitClient.request( @@ -436,43 +622,14 @@ describe('Task Lifecycle Integration Tests', () => { const taskId = createResult.task.taskId; - // Wait for elicitation to occur and task to transition to input_required + // Wait for elicitation to occur await new Promise(resolve => setTimeout(resolve, 200)); - // Check task status - should be input_required during elicitation - let task = await elicitClient.request( - { - method: 'tasks/get', - params: { taskId } - }, - TaskSchema - ); - - // Task should either be input_required or already completed (if elicitation was fast) - expect(['input_required', 'working', 'completed']).toContain(task.status); - - // Wait for completion after elicitation response - // Poll until task completes or times out - let attempts = 0; - while (attempts < 20) { - task = await elicitClient.request( - { - method: 'tasks/get', - params: { taskId } - }, - TaskSchema - ); - if (task.status === 'completed' || task.status === 'failed') { - break; - } - await new Promise(resolve => setTimeout(resolve, 100)); - attempts++; - } - - // Verify task completed with elicited input - expect(task.status).toBe('completed'); + // Check if the elicitation request was queued - // Get result + // Call tasks/result to receive the queued elicitation request + // This should deliver the elicitation request via the side-channel + // and then deliver the final result after the client responds const result = await elicitClient.request( { method: 'tasks/result', @@ -481,11 +638,21 @@ describe('Task Lifecycle Integration Tests', () => { CallToolResultSchema ); + // Verify final result is delivered correctly expect(result.content).toEqual([{ type: 'text', text: 'Hello, Alice!' }]); - await elicitClientTransport.close(); - await elicitServerTransport.close(); - }); + // Verify task is now completed + const task = await elicitClient.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); + + await transport.close(); + }, 10000); // Increase timeout to 10 seconds for debugging it('should complete immediately when input is provided upfront', async () => { const client = new Client({ @@ -728,6 +895,622 @@ describe('Task Lifecycle Integration Tests', () => { }); }); + describe('Task Cancellation with Queued Messages', () => { + it('should clear queue and deliver no messages when task is cancelled before tasks/result', async () => { + // Register a tool that queues messages but doesn't complete immediately + mcpServer.registerToolTask( + 'cancellable-task', + { + title: 'Cancellable Task', + description: 'A tool that queues messages and can be cancelled', + inputSchema: { + messageCount: z.number().describe('Number of messages to queue').default(2) + } + }, + { + async createTask({ messageCount }, extra) { + const task = await extra.taskStore.createTask( + { + ttl: 60000, + pollInterval: 100 + }, + 0, + { method: 'tools/call', params: { name: 'cancellable-task', arguments: { messageCount } } } + ); + + // Perform async work that queues messages + (async () => { + try { + await new Promise(resolve => setTimeout(resolve, 100)); + + // Queue multiple elicitation requests + for (let i = 0; i < messageCount; i++) { + // Send request but don't await - let it queue + extra + .sendRequest( + { + method: 'elicitation/create', + params: { + message: `Message ${i + 1} of ${messageCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] + } + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ) + .catch(() => { + // Ignore errors from cancelled requests + }); + } + + // Don't complete - let the task be cancelled + // Wait indefinitely (or until cancelled) + await new Promise(() => {}); + } catch { + // Ignore errors - task was cancelled + } + })().catch(() => { + // Catch any unhandled errors from the async execution + }); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + let elicitationCallCount = 0; + + // Set up elicitation handler to track if any messages are delivered + client.setRequestHandler(ElicitRequestSchema, async () => { + elicitationCallCount++; + return { + action: 'accept' as const, + content: { + response: 'Should not be called' + } + }; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will queue messages + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'cancellable-task', + arguments: { + messageCount: 2 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for messages to be queued + await new Promise(resolve => setTimeout(resolve, 200)); + + // Verify task is working and messages are queued + let task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('working'); + + // Cancel the task before calling tasks/result using the proper tasks/cancel request + // This will trigger queue cleanup via _clearTaskQueue in the handler + await client.request( + { + method: 'tasks/cancel', + params: { taskId } + }, + z.object({ _meta: z.record(z.unknown()).optional() }) + ); + + // Verify task is cancelled + task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('cancelled'); + + // Attempt to call tasks/result + // According to Requirement 4.2: "WHEN a task is cancelled THEN the system SHALL + // clear the message queue and reject any pending message delivery promises" + // This means NO messages should be delivered for a cancelled task + try { + await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + } catch { + // tasks/result might throw an error for cancelled tasks without a result + // This is acceptable behavior + } + + // Verify no elicitation messages were delivered + // This validates Property 12: queue should be cleared immediately on cancellation + expect(elicitationCallCount).toBe(0); + + // Verify queue remains cleared on subsequent calls + try { + await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + } catch { + // Expected - task is cancelled + } + + // Still no messages should have been delivered + expect(elicitationCallCount).toBe(0); + + await transport.close(); + }, 10000); + }); + + describe('Continuous Message Delivery', () => { + it('should deliver messages immediately while tasks/result is blocking', async () => { + // Register a tool that queues messages over time + mcpServer.registerToolTask( + 'streaming-task', + { + title: 'Streaming Task', + description: 'A tool that sends messages over time', + inputSchema: { + messageCount: z.number().describe('Number of messages to send').default(3), + delayBetweenMessages: z.number().describe('Delay between messages in ms').default(200) + } + }, + { + async createTask({ messageCount, delayBetweenMessages }, extra) { + const task = await extra.taskStore.createTask( + { + ttl: 60000, + pollInterval: 100 + }, + 0, + { method: 'tools/call', params: { name: 'streaming-task', arguments: { messageCount, delayBetweenMessages } } } + ); + + // Perform async work that sends messages over time + (async () => { + try { + // Wait a bit before starting to send messages + await new Promise(resolve => setTimeout(resolve, 100)); + + const responses: string[] = []; + + // Send messages with delays between them + for (let i = 0; i < messageCount; i++) { + const elicitationResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + message: `Streaming message ${i + 1} of ${messageCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] + } + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ); + + if (elicitationResult.action === 'accept' && elicitationResult.content) { + responses.push(elicitationResult.content.response as string); + } + + // Wait before sending next message (if not the last one) + if (i < messageCount - 1) { + await new Promise(resolve => setTimeout(resolve, delayBetweenMessages)); + } + } + + // Complete with all responses + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Received all responses: ${responses.join(', ')}` }] + }); + } catch { + // Task may have been cleaned up if test ended + } + } catch (error) { + // Handle errors + try { + await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + content: [{ type: 'text', text: `Error: ${error}` }], + isError: true + }); + } catch { + // Task may have been cleaned up if test ended + } + } + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + const receivedMessages: Array<{ message: string; timestamp: number }> = []; + let tasksResultStartTime = 0; + + // Set up elicitation handler to track when messages arrive + client.setRequestHandler(ElicitRequestSchema, async request => { + const timestamp = Date.now(); + receivedMessages.push({ + message: request.params.message, + timestamp + }); + + // Extract the message number + const match = request.params.message.match(/Streaming message (\d+) of (\d+)/); + const messageNum = match ? match[1] : 'unknown'; + + // Respond immediately + return { + action: 'accept' as const, + content: { + response: `Response ${messageNum}` + } + }; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will send messages over time + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'streaming-task', + arguments: { + messageCount: 3, + delayBetweenMessages: 300 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Verify task is in working status + let task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('working'); + + // Call tasks/result immediately (before messages are queued) + // This should block and deliver messages as they arrive + tasksResultStartTime = Date.now(); + const resultPromise = client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + // Wait for the task to complete and get the result + const result = await resultPromise; + + // Verify all 3 messages were delivered + expect(receivedMessages.length).toBe(3); + expect(receivedMessages[0].message).toBe('Streaming message 1 of 3'); + expect(receivedMessages[1].message).toBe('Streaming message 2 of 3'); + expect(receivedMessages[2].message).toBe('Streaming message 3 of 3'); + + // Verify messages were delivered over time (not all at once) + // The delay between messages should be approximately 300ms + const timeBetweenFirstAndSecond = receivedMessages[1].timestamp - receivedMessages[0].timestamp; + const timeBetweenSecondAndThird = receivedMessages[2].timestamp - receivedMessages[1].timestamp; + + // Allow some tolerance for timing (messages should be at least 200ms apart) + expect(timeBetweenFirstAndSecond).toBeGreaterThan(200); + expect(timeBetweenSecondAndThird).toBeGreaterThan(200); + + // Verify messages were delivered while tasks/result was blocking + // (all messages should arrive after tasks/result was called) + for (const msg of receivedMessages) { + expect(msg.timestamp).toBeGreaterThanOrEqual(tasksResultStartTime); + } + + // Verify final result is correct + expect(result.content).toEqual([{ type: 'text', text: 'Received all responses: Response 1, Response 2, Response 3' }]); + + // Verify task is now completed + task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); + + await transport.close(); + }, 15000); // Increase timeout to 15 seconds to allow for message delays + }); + + describe('Terminal Task with Queued Messages', () => { + it('should deliver queued messages followed by final result for terminal task', async () => { + // Register a tool that completes quickly and queues messages before completion + mcpServer.registerToolTask( + 'quick-complete-task', + { + title: 'Quick Complete Task', + description: 'A tool that queues messages and completes quickly', + inputSchema: { + messageCount: z.number().describe('Number of messages to queue').default(2) + } + }, + { + async createTask({ messageCount }, extra) { + const task = await extra.taskStore.createTask( + { + ttl: 60000, + pollInterval: 100 + }, + 0, + { method: 'tools/call', params: { name: 'quick-complete-task', arguments: { messageCount } } } + ); + + // Perform async work that queues messages and completes quickly + (async () => { + try { + // Queue messages without waiting for responses + const pendingRequests: Promise[] = []; + + for (let i = 0; i < messageCount; i++) { + const requestPromise = extra.sendRequest( + { + method: 'elicitation/create', + params: { + message: `Quick message ${i + 1} of ${messageCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] + } + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ); + pendingRequests.push(requestPromise); + } + + // Complete the task immediately (before responses are received) + // This creates a terminal task with queued messages + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: 'Task completed quickly' }] + }); + } catch { + // Task may have been cleaned up if test ended + } + + // Wait for all responses in the background + await Promise.all(pendingRequests.map(p => p.catch(() => {}))); + } catch (error) { + // Handle errors + try { + await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + content: [{ type: 'text', text: `Error: ${error}` }], + isError: true + }); + } catch { + // Task may have been cleaned up if test ended + } + } + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + const receivedMessages: Array<{ type: string; message?: string; content?: unknown }> = []; + + // Set up elicitation handler to track message order + client.setRequestHandler(ElicitRequestSchema, async request => { + receivedMessages.push({ + type: 'elicitation', + message: request.params.message + }); + + // Extract the message number + const match = request.params.message.match(/Quick message (\d+) of (\d+)/); + const messageNum = match ? match[1] : 'unknown'; + + return { + action: 'accept' as const, + content: { + response: `Response ${messageNum}` + } + }; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will complete quickly with queued messages + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'quick-complete-task', + arguments: { + messageCount: 2 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for task to complete and messages to be queued + await new Promise(resolve => setTimeout(resolve, 200)); + + // Verify task is in terminal status (completed) + const task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); + + // Call tasks/result - should deliver queued messages followed by final result + const result = await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + // Verify all queued messages were delivered before the final result + expect(receivedMessages.length).toBe(2); + expect(receivedMessages[0].message).toBe('Quick message 1 of 2'); + expect(receivedMessages[1].message).toBe('Quick message 2 of 2'); + + // Verify final result is correct + expect(result.content).toEqual([{ type: 'text', text: 'Task completed quickly' }]); + + // Verify queue is cleaned up - calling tasks/result again should only return the result + receivedMessages.length = 0; // Clear the array + + const result2 = await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + // No messages should be delivered on second call (queue was cleaned up) + expect(receivedMessages.length).toBe(0); + expect(result2.content).toEqual([{ type: 'text', text: 'Task completed quickly' }]); + + await transport.close(); + }, 10000); + }); + describe('Concurrent Operations', () => { it('should handle multiple concurrent task creations', async () => { const client = new Client({ diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index bf7d4dcfd..20f5a5e72 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -13,10 +13,22 @@ import { Task, TaskCreationParams } from '../types.js'; -import { Protocol, mergeCapabilities } from './protocol.js'; +import { Protocol, mergeCapabilities, TaskMessageQueue } from './protocol.js'; import { Transport } from './transport.js'; import { TaskStore } from './task.js'; import { MockInstance, vi } from 'vitest'; +import { JSONRPCResponse, JSONRPCRequest, JSONRPCError } from '../types.js'; + +// Type helper for accessing private Protocol properties in tests +interface TestProtocol { + _taskMessageQueues: Map; + _taskResultWaiters: Map void>>; + _requestResolvers: Map void>; + _responseHandlers: Map void>; + _taskProgressTokens: Map; + _clearTaskQueue: (taskId: string) => void; + requestTaskStore: (request: Request, authInfo: unknown) => TaskStore; +} // Mock Transport class class MockTransport implements Transport { @@ -759,6 +771,184 @@ describe('protocol tests', () => { }); }); +describe('TaskMessageQueue', () => { + let queue: TaskMessageQueue; + + beforeEach(() => { + queue = new TaskMessageQueue(); + }); + + describe('enqueue/dequeue maintains FIFO order', () => { + it('should maintain FIFO order for multiple messages', () => { + const msg1 = { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test1' }, + timestamp: 1 + }; + const msg2 = { + type: 'request' as const, + message: { jsonrpc: '2.0' as const, id: 1, method: 'test2' }, + timestamp: 2 + }; + const msg3 = { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test3' }, + timestamp: 3 + }; + + queue.enqueue(msg1); + queue.enqueue(msg2); + queue.enqueue(msg3); + + expect(queue!.dequeue()).toEqual(msg1); + expect(queue!.dequeue()).toEqual(msg2); + expect(queue!.dequeue()).toEqual(msg3); + }); + + it('should return undefined when dequeuing from empty queue', () => { + expect(queue!.dequeue()).toBeUndefined(); + }); + }); + + describe('clear operation', () => { + it('should remove all messages from queue', () => { + queue.enqueue({ + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test1' }, + timestamp: 1 + }); + queue.enqueue({ + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test2' }, + timestamp: 2 + }); + queue.enqueue({ + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test3' }, + timestamp: 3 + }); + + expect(queue!.size()).toBe(3); + + queue.clear(); + + expect(queue!.size()).toBe(0); + expect(queue.isEmpty()).toBe(true); + expect(queue!.dequeue()).toBeUndefined(); + }); + + it('should work on empty queue', () => { + expect(() => queue.clear()).not.toThrow(); + expect(queue.isEmpty()).toBe(true); + }); + }); + + describe('isEmpty and size methods', () => { + it('should return true for empty queue', () => { + expect(queue.isEmpty()).toBe(true); + expect(queue!.size()).toBe(0); + }); + + it('should return false after enqueuing', () => { + queue.enqueue({ + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test' }, + timestamp: 1 + }); + expect(queue.isEmpty()).toBe(false); + expect(queue!.size()).toBe(1); + }); + + it('should return correct size for multiple messages', () => { + for (let i = 0; i < 5; i++) { + queue.enqueue({ + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: `test${i}` }, + timestamp: i + }); + } + expect(queue!.size()).toBe(5); + expect(queue.isEmpty()).toBe(false); + }); + + it('should update size correctly after dequeue', () => { + queue.enqueue({ + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test1' }, + timestamp: 1 + }); + queue.enqueue({ + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test2' }, + timestamp: 2 + }); + expect(queue!.size()).toBe(2); + + queue!.dequeue(); + expect(queue!.size()).toBe(1); + expect(queue.isEmpty()).toBe(false); + + queue!.dequeue(); + expect(queue!.size()).toBe(0); + expect(queue.isEmpty()).toBe(true); + }); + }); + + describe('dequeueAll operation', () => { + it('should return all messages in FIFO order', () => { + const msg1 = { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test1' }, + timestamp: 1 + }; + const msg2 = { + type: 'request' as const, + message: { jsonrpc: '2.0' as const, id: 1, method: 'test2' }, + timestamp: 2 + }; + const msg3 = { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test3' }, + timestamp: 3 + }; + + queue.enqueue(msg1); + queue.enqueue(msg2); + queue.enqueue(msg3); + + const allMessages = queue.dequeueAll(); + + expect(allMessages).toEqual([msg1, msg2, msg3]); + expect(queue.isEmpty()).toBe(true); + expect(queue!.size()).toBe(0); + }); + + it('should return empty array for empty queue', () => { + const allMessages = queue.dequeueAll(); + expect(allMessages).toEqual([]); + expect(queue.isEmpty()).toBe(true); + }); + + it('should clear queue after dequeueAll', () => { + queue.enqueue({ + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test1' }, + timestamp: 1 + }); + queue.enqueue({ + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test2' }, + timestamp: 2 + }); + + queue.dequeueAll(); + + expect(queue!.dequeue()).toBeUndefined(); + expect(queue!.size()).toBe(0); + }); + }); +}); + describe('mergeCapabilities', () => { it('should merge client capabilities', () => { const base: ClientCapabilities = { @@ -1420,8 +1610,7 @@ describe('Task-based execution', () => { 'Client cancelled task execution.', undefined ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const sentMessage = sendSpy.mock.calls[0][0] as any; + const sentMessage = sendSpy.mock.calls[0][0] as unknown as JSONRPCResponse; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(5); expect(sentMessage.result._meta).toBeDefined(); @@ -1459,8 +1648,7 @@ describe('Task-based execution', () => { taskDeleted.releaseLatch(); expect(mockTaskStore.getTask).toHaveBeenCalledWith('non-existent', undefined); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const sentMessage = sendSpy.mock.calls[0][0] as any; + const sentMessage = sendSpy.mock.calls[0][0] as unknown as JSONRPCError; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(6); expect(sentMessage.error).toBeDefined(); @@ -1508,8 +1696,7 @@ describe('Task-based execution', () => { expect(mockTaskStore.getTask).toHaveBeenCalledWith(completedTask.taskId, undefined); expect(mockTaskStore.updateTaskStatus).not.toHaveBeenCalled(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const sentMessage = sendSpy.mock.calls[0][0] as any; + const sentMessage = sendSpy.mock.calls[0][0] as unknown as JSONRPCError; expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(7); expect(sentMessage.error).toBeDefined(); @@ -2332,16 +2519,14 @@ describe('Progress notification support for tasks', () => { expect(progressCallback).toHaveBeenCalledTimes(1); // Verify the task-progress association was created - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const taskProgressTokens = (protocol as any)._taskProgressTokens as Map; + const taskProgressTokens = (protocol as unknown as TestProtocol)._taskProgressTokens as Map; expect(taskProgressTokens.has(taskId)).toBe(true); expect(taskProgressTokens.get(taskId)).toBe(progressToken); // Simulate task completion by calling through the protocol's task store // This will trigger the cleanup logic const mockRequest = { jsonrpc: '2.0' as const, id: 999, method: 'test', params: {} }; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const requestTaskStore = (protocol as any).requestTaskStore(mockRequest, undefined); + const requestTaskStore = (protocol as unknown as TestProtocol).requestTaskStore(mockRequest, undefined); await requestTaskStore.storeTaskResult(taskId, 'completed', { content: [] }); // Wait for all async operations including notification sending to complete @@ -2831,3 +3016,1249 @@ describe('Capability negotiation for tasks', () => { expect('cancel' in clientCapabilities.tasks).toBe(true); }); }); + +describe('Message interception for task-related notifications', () => { + it('should queue notifications with io.modelcontextprotocol/related-task metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + // Create a task first + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a notification with related task metadata + await server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { taskId: task.taskId } + } + ); + + // Access the private queue to verify the message was queued + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + expect(queue).toBeDefined(); + expect(queue!.size()).toBe(1); + + const queuedMessage = queue!.dequeue(); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('notification'); + expect(queuedMessage?.message.method).toBe('notifications/message'); + expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); + }); + + it('should not queue notifications without related-task metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + // Send a notification without related task metadata + await server.notification({ + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }); + + // Verify no queues were created + const queues = (server as unknown as TestProtocol)._taskMessageQueues; + expect(queues.size).toBe(0); + }); + + it('should notify task result waiters after queuing', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Set up a waiter + let waiterCalled = false; + const waiters = (server as unknown as TestProtocol)._taskResultWaiters; + waiters.set(task.taskId, [ + () => { + waiterCalled = true; + } + ]); + + // Send a notification with related task metadata + await server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { taskId: task.taskId } + } + ); + + // Verify the waiter was called + expect(waiterCalled).toBe(true); + expect(waiters.has(task.taskId)).toBe(false); // Waiters should be cleared + }); + + it('should handle queue overflow by failing the task', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, maxTaskQueueSize: 100 }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Fill the queue to max capacity (100 messages) + for (let i = 0; i < 100; i++) { + await server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: `message ${i}` } + }, + { + relatedTask: { taskId: task.taskId } + } + ); + } + + // Verify queue is at max capacity + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + expect(queue!.size()).toBe(100); + + // Try to add one more message - should throw and fail the task + await expect( + server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'overflow message' } + }, + { + relatedTask: { taskId: task.taskId } + } + ) + ).rejects.toThrow(McpError); + + // Verify the task was failed + expect(taskStore.updateTaskStatus).toHaveBeenCalledWith(task.taskId, 'failed', 'Task message queue overflow'); + + // Verify the queue was cleared + expect(queue!.size()).toBe(0); + }); + + it('should extract task ID correctly from metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + const taskId = 'custom-task-id-123'; + + // Send a notification with custom task ID + await server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { taskId } + } + ); + + // Verify the message was queued under the correct task ID + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(taskId); + expect(queue).toBeDefined(); + expect(queue!.size()).toBe(1); + }); + + it('should preserve message order when queuing multiple notifications', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send multiple notifications + for (let i = 0; i < 5; i++) { + await server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: `message ${i}` } + }, + { + relatedTask: { taskId: task.taskId } + } + ); + } + + // Verify messages are in FIFO order + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + expect(queue!.size()).toBe(5); + + for (let i = 0; i < 5; i++) { + const message = queue!.dequeue(); + expect(message!.message.params!.data).toBe(`message ${i}`); + } + }); +}); + +describe('Message interception for task-related requests', () => { + it('should queue requests with io.modelcontextprotocol/related-task metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + // Create a task first + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a request with related task metadata (don't await - we're testing queuing) + const requestPromise = server.request( + { + method: 'ping', + params: {} + }, + z.object({}), + { + relatedTask: { taskId: task.taskId } + } + ); + + // Access the private queue to verify the message was queued + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + expect(queue).toBeDefined(); + expect(queue!.size()).toBe(1); + + const queuedMessage = queue!.dequeue(); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('request'); + expect(queuedMessage?.message.method).toBe('ping'); + expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); + expect(queuedMessage?.responseResolver).toBeDefined(); + expect(queuedMessage!.originalRequestId!).toBeDefined(); + + // Clean up - send a response to prevent hanging promise + transport.onmessage?.({ + jsonrpc: '2.0', + id: queuedMessage!.originalRequestId!, + result: {} + }); + + await requestPromise; + }); + + it('should not queue requests without related-task metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + // Send a request without related task metadata + const requestPromise = server.request( + { + method: 'ping', + params: {} + }, + z.object({}) + ); + + // Verify no queues were created + const queues = (server as unknown as TestProtocol)._taskMessageQueues; + expect(queues.size).toBe(0); + + // Clean up - send a response + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: {} + }); + + await requestPromise; + }); + + it('should notify task result waiters after queuing request', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Set up a waiter + let waiterCalled = false; + const waiters = (server as unknown as TestProtocol)._taskResultWaiters; + waiters.set(task.taskId, [ + () => { + waiterCalled = true; + } + ]); + + // Send a request with related task metadata + const requestPromise = server.request( + { + method: 'ping', + params: {} + }, + z.object({}), + { + relatedTask: { taskId: task.taskId } + } + ); + + // Verify the waiter was called + expect(waiterCalled).toBe(true); + expect(waiters.has(task.taskId)).toBe(false); // Waiters should be cleared + + // Clean up + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + const queuedMessage = queue!.dequeue(); + transport.onmessage?.({ + jsonrpc: '2.0', + id: queuedMessage!.originalRequestId!, + result: {} + }); + + await requestPromise; + }); + + it('should store request resolver for response routing', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a request with related task metadata + const requestPromise = server.request( + { + method: 'ping', + params: {} + }, + z.object({}), + { + relatedTask: { taskId: task.taskId } + } + ); + + // Verify the resolver was stored + const resolvers = (server as unknown as TestProtocol)._requestResolvers; + expect(resolvers.size).toBe(1); + + // Get the request ID from the queue + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + const queuedMessage = queue!.dequeue(); + const requestId = queuedMessage!.originalRequestId!; + + expect(resolvers.has(requestId)).toBe(true); + + // Send a response to trigger resolver + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + result: {} + }); + + await requestPromise; + + // Verify resolver was cleaned up after response + expect(resolvers.has(requestId)).toBe(false); + }); + + it('should route responses to side-channeled requests', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a request with related task metadata + const requestPromise = server.request( + { + method: 'ping', + params: {} + }, + z.object({ message: z.string() }), + { + relatedTask: { taskId: task.taskId } + } + ); + + // Get the request ID from the queue + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + const queuedMessage = queue!.dequeue(); + const requestId = queuedMessage!.originalRequestId!; + + // Send a response + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + result: { message: 'pong' } + }); + + // Verify the response was routed correctly + const result = await requestPromise; + expect(result).toEqual({ message: 'pong' }); + }); + + it('should log error when resolver is missing for side-channeled request', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore }); + + const errors: Error[] = []; + server.onerror = (error: Error) => { + errors.push(error); + }; + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a request with related task metadata + void server.request( + { + method: 'ping', + params: {} + }, + z.object({ message: z.string() }), + { + relatedTask: { taskId: task.taskId } + } + ); + + // Get the request ID from the queue + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + const queuedMessage = queue!.dequeue(); + const requestId = queuedMessage!.originalRequestId!; + + // Manually delete the response handler to simulate missing resolver + (server as unknown as TestProtocol)._responseHandlers.delete(requestId); + + // Send a response - this should trigger the error logging + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + result: { message: 'pong' } + }); + + // Wait a bit for the error to be logged + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify error was logged + expect(errors.length).toBe(1); + expect(errors[0].message).toContain('Response handler missing for side-channeled request'); + }); + + it('should handle queue overflow for requests', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, maxTaskQueueSize: 100 }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Fill the queue to max capacity (100 messages) + const promises: Promise[] = []; + for (let i = 0; i < 100; i++) { + const promise = server + .request( + { + method: 'ping', + params: {} + }, + z.object({}), + { + relatedTask: { taskId: task.taskId } + } + ) + .catch(() => { + // Expected to reject when queue is cleared + }); + promises.push(promise); + } + + // Verify queue is at max capacity + const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + expect(queue!.size()).toBe(100); + + // Try to add one more request - should throw and fail the task + await expect( + server.request( + { + method: 'ping', + params: {} + }, + z.object({}), + { + relatedTask: { taskId: task.taskId } + } + ) + ).rejects.toThrow(McpError); + + // Verify the task was failed + expect(taskStore.updateTaskStatus).toHaveBeenCalledWith(task.taskId, 'failed', 'Task message queue overflow'); + + // Verify the queue was cleared + expect(queue!.size()).toBe(0); + }); +}); + +describe('Message Interception', () => { + let protocol: Protocol; + let transport: MockTransport; + let mockTaskStore: TaskStore & { [K in keyof TaskStore]: MockInstance }; + + beforeEach(() => { + transport = new MockTransport(); + mockTaskStore = createMockTaskStore(); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + }); + + describe('messages with relatedTask metadata are queued', () => { + it('should queue notifications with relatedTask metadata', async () => { + await protocol.connect(transport); + + // Send a notification with relatedTask metadata + await protocol.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { + taskId: 'task-123' + } + } + ); + + // Access the private _taskMessageQueues to verify the message was queued + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has('task-123')).toBe(true); + + const queue = queues.get('task-123')!; + expect(queue!.size()).toBe(1); + + const queuedMessage = queue!.dequeue(); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage!.type).toBe('notification'); + expect(queuedMessage!.message.method).toBe('notifications/message'); + }); + + it('should queue requests with relatedTask metadata', async () => { + await protocol.connect(transport); + + const mockSchema = z.object({ result: z.string() }); + + // Send a request with relatedTask metadata + const requestPromise = protocol.request( + { + method: 'test/request', + params: { data: 'test' } + }, + mockSchema, + { + relatedTask: { + taskId: 'task-456' + } + } + ); + + // Access the private _taskMessageQueues to verify the message was queued + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has('task-456')).toBe(true); + + const queue = queues.get('task-456')!; + expect(queue!.size()).toBe(1); + + const queuedMessage = queue!.dequeue(); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage!.type).toBe('request'); + expect(queuedMessage!.message.method).toBe('test/request'); + expect(queuedMessage!.responseResolver).toBeDefined(); + + // Clean up the pending request + transport.onmessage?.({ + jsonrpc: '2.0', + id: (queuedMessage!.message as JSONRPCRequest).id, + result: { result: 'success' } + }); + await requestPromise; + }); + }); + + describe('messages without metadata bypass the queue', () => { + it('should not queue notifications without relatedTask metadata', async () => { + await protocol.connect(transport); + + // Send a notification without relatedTask metadata + await protocol.notification({ + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }); + + // Access the private _taskMessageQueues to verify no queue was created + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.size).toBe(0); + }); + + it('should not queue requests without relatedTask metadata', async () => { + await protocol.connect(transport); + + const mockSchema = z.object({ result: z.string() }); + const sendSpy = vi.spyOn(transport, 'send'); + + // Send a request without relatedTask metadata + const requestPromise = protocol.request( + { + method: 'test/request', + params: { data: 'test' } + }, + mockSchema + ); + + // Access the private _taskMessageQueues to verify no queue was created + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.size).toBe(0); + + // Clean up the pending request + const requestId = (sendSpy.mock.calls[0][0] as JSONRPCResponse).id; + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + result: { result: 'success' } + }); + await requestPromise; + }); + }); + + describe('task ID extraction from metadata', () => { + it('should extract correct task ID from relatedTask metadata for notifications', async () => { + await protocol.connect(transport); + + const taskId = 'extracted-task-789'; + + // Send a notification with relatedTask metadata + await protocol.notification( + { + method: 'notifications/message', + params: { data: 'test' } + }, + { + relatedTask: { + taskId: taskId + } + } + ); + + // Verify the message was queued under the correct task ID + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has(taskId)).toBe(true); + expect(queues.has('wrong-task-id')).toBe(false); + }); + + it('should extract correct task ID from relatedTask metadata for requests', async () => { + await protocol.connect(transport); + + const taskId = 'extracted-task-999'; + const mockSchema = z.object({ result: z.string() }); + + // Send a request with relatedTask metadata + const requestPromise = protocol.request( + { + method: 'test/request', + params: { data: 'test' } + }, + mockSchema, + { + relatedTask: { + taskId: taskId + } + } + ); + + // Verify the message was queued under the correct task ID + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has(taskId)).toBe(true); + expect(queues.has('wrong-task-id')).toBe(false); + + // Clean up the pending request + const queue = queues.get(taskId)!; + const queuedMessage = queue!.dequeue(); + transport.onmessage?.({ + jsonrpc: '2.0', + id: (queuedMessage!.message as JSONRPCRequest).id, + result: { result: 'success' } + }); + await requestPromise; + }); + + it('should handle multiple messages for different task IDs', async () => { + await protocol.connect(transport); + + // Send messages for different tasks + await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'task-A' } }); + await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'task-B' } }); + await protocol.notification({ method: 'test3', params: {} }, { relatedTask: { taskId: 'task-A' } }); + + // Verify messages are queued under correct task IDs + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has('task-A')).toBe(true); + expect(queues.has('task-B')).toBe(true); + + const queueA = queues.get('task-A')!; + const queueB = queues.get('task-B')!; + + expect(queueA.size()).toBe(2); // Two messages for task-A + expect(queueB.size()).toBe(1); // One message for task-B + }); + }); + + describe('queue creation on first message', () => { + it('should create queue on first message for a task', async () => { + await protocol.connect(transport); + + // Verify no queues exist initially + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.size).toBe(0); + + // Send first message for a task + await protocol.notification({ method: 'test', params: {} }, { relatedTask: { taskId: 'new-task' } }); + + // Verify queue was created + expect(queues.has('new-task')).toBe(true); + expect(queues.size).toBe(1); + }); + + it('should reuse existing queue for subsequent messages', async () => { + await protocol.connect(transport); + + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + + // Send first message + await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'reuse-task' } }); + + const firstQueue = queues.get('reuse-task'); + expect(firstQueue).toBeDefined(); + expect(firstQueue!.size()).toBe(1); + + // Send second message + await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'reuse-task' } }); + + const secondQueue = queues.get('reuse-task'); + + // Should be the same queue instance + expect(secondQueue).toBe(firstQueue); + expect(secondQueue!.size()).toBe(2); + }); + + it('should create separate queues for different tasks', async () => { + await protocol.connect(transport); + + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + + // Send messages for different tasks + await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'task-1' } }); + await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'task-2' } }); + + // Verify separate queues were created + expect(queues.size).toBe(2); + expect(queues.has('task-1')).toBe(true); + expect(queues.has('task-2')).toBe(true); + + const queue1 = queues.get('task-1')!; + const queue2 = queues.get('task-2')!; + + // Verify they are different queue instances + expect(queue1).not.toBe(queue2); + }); + }); + + describe('metadata preservation in queued messages', () => { + it('should preserve relatedTask metadata in queued notification', async () => { + await protocol.connect(transport); + + const relatedTask = { taskId: 'task-meta-123' }; + + await protocol.notification( + { + method: 'test/notification', + params: { data: 'test' } + }, + { relatedTask } + ); + + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + const queue = queues.get('task-meta-123')!; + const queuedMessage = queue!.dequeue(); + + // Verify the metadata is preserved in the queued message + expect(queuedMessage!.message.params!._meta).toBeDefined(); + expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); + }); + + it('should preserve relatedTask metadata in queued request', async () => { + await protocol.connect(transport); + + const relatedTask = { taskId: 'task-meta-456' }; + const mockSchema = z.object({ result: z.string() }); + + const requestPromise = protocol.request( + { + method: 'test/request', + params: { data: 'test' } + }, + mockSchema, + { relatedTask } + ); + + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + const queue = queues.get('task-meta-456')!; + const queuedMessage = queue!.dequeue(); + + // Verify the metadata is preserved in the queued message + expect(queuedMessage!.message.params!._meta).toBeDefined(); + expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); + + // Clean up + transport.onmessage?.({ + jsonrpc: '2.0', + id: (queuedMessage!.message as JSONRPCRequest).id, + result: { result: 'success' } + }); + await requestPromise; + }); + + it('should preserve existing _meta fields when adding relatedTask', async () => { + await protocol.connect(transport); + + await protocol.notification( + { + method: 'test/notification', + params: { + data: 'test', + _meta: { + customField: 'customValue', + anotherField: 123 + } + } + }, + { + relatedTask: { taskId: 'task-preserve-meta' } + } + ); + + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + const queue = queues.get('task-preserve-meta')!; + const queuedMessage = queue!.dequeue(); + + // Verify both existing and new metadata are preserved + expect(queuedMessage!.message.params!._meta!.customField).toBe('customValue'); + expect(queuedMessage!.message.params!._meta!.anotherField).toBe(123); + expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ + taskId: 'task-preserve-meta' + }); + }); + }); +}); + +describe('Queue lifecycle management', () => { + let protocol: Protocol; + let transport: MockTransport; + let mockTaskStore: TaskStore & { [K in keyof TaskStore]: MockInstance }; + + beforeEach(() => { + transport = new MockTransport(); + mockTaskStore = createMockTaskStore(); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + }); + + describe('queue cleanup on task completion', () => { + it('should clear queue when task reaches completed status', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue some messages for the task + await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); + await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); + + // Verify messages are queued + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has(taskId)).toBe(true); + expect(queues.get(taskId)!.size()).toBe(2); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // Verify queue is cleared + expect(queues.has(taskId)).toBe(false); + }); + + it('should clear queue after delivering messages on tasks/result for completed task', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue a message + await protocol.notification({ method: 'test/notification', params: { data: 'test' } }, { relatedTask: { taskId } }); + + // Mark task as completed + const completedTask = { ...task, status: 'completed' as const }; + mockTaskStore.getTask.mockResolvedValue(completedTask); + mockTaskStore.getTaskResult.mockResolvedValue({ content: [{ type: 'text', text: 'done' }] }); + + // Simulate tasks/result request + const resultPromise = new Promise(resolve => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: 100, + method: 'tasks/result', + params: { taskId } + }); + setTimeout(resolve, 50); + }); + + await resultPromise; + + // Verify queue is cleared after delivery + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has(taskId)).toBe(false); + }); + }); + + describe('queue cleanup on task cancellation', () => { + it('should clear queue when task is cancelled', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue some messages + await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); + + // Verify message is queued + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has(taskId)).toBe(true); + expect(queues.get(taskId)!.size()).toBe(1); + + // Mock task as non-terminal + mockTaskStore.getTask.mockResolvedValue(task); + + // Cancel the task + transport.onmessage?.({ + jsonrpc: '2.0', + id: 200, + method: 'tasks/cancel', + params: { taskId } + }); + + // Wait for cancellation to process + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify queue is cleared + expect(queues.has(taskId)).toBe(false); + }); + + it('should reject pending request resolvers when task is cancelled', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue a request (catch rejection to avoid unhandled promise rejection) + const requestPromise = protocol + .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + // Verify request is queued + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has(taskId)).toBe(true); + expect(queues.get(taskId)!.size()).toBe(1); + + // Mock task as non-terminal + mockTaskStore.getTask.mockResolvedValue(task); + + // Cancel the task + transport.onmessage?.({ + jsonrpc: '2.0', + id: 201, + method: 'tasks/cancel', + params: { taskId } + }); + + // Wait for cancellation to process + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the request promise is rejected + const result = await requestPromise; + expect(result).toBeInstanceOf(McpError); + expect(result.message).toContain('Task cancelled or completed'); + + // Verify queue is cleared + expect(queues.has(taskId)).toBe(false); + }); + }); + + describe('queue cleanup on task failure', () => { + it('should clear queue when task reaches failed status', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue some messages + await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); + await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); + + // Verify messages are queued + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has(taskId)).toBe(true); + expect(queues.get(taskId)!.size()).toBe(2); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // Verify queue is cleared + expect(queues.has(taskId)).toBe(false); + }); + + it('should reject pending request resolvers when task fails', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue a request (catch the rejection to avoid unhandled promise rejection) + const requestPromise = protocol + .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + // Verify request is queued + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has(taskId)).toBe(true); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // Verify the request promise is rejected + const result = await requestPromise; + expect(result).toBeInstanceOf(McpError); + expect(result.message).toContain('Task cancelled or completed'); + + // Verify queue is cleared + expect(queues.has(taskId)).toBe(false); + }); + }); + + describe('resolver rejection on cleanup', () => { + it('should reject all pending request resolvers when queue is cleared', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue multiple requests (catch rejections to avoid unhandled promise rejections) + const request1Promise = protocol + .request({ method: 'test/request1', params: { data: 'test1' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + const request2Promise = protocol + .request({ method: 'test/request2', params: { data: 'test2' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + const request3Promise = protocol + .request({ method: 'test/request3', params: { data: 'test3' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + // Verify requests are queued + const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + expect(queues.has(taskId)).toBe(true); + expect(queues.get(taskId)!.size()).toBe(3); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // Verify all request promises are rejected + const result1 = await request1Promise; + const result2 = await request2Promise; + const result3 = await request3Promise; + + expect(result1).toBeInstanceOf(McpError); + expect(result1.message).toContain('Task cancelled or completed'); + expect(result2).toBeInstanceOf(McpError); + expect(result2.message).toContain('Task cancelled or completed'); + expect(result3).toBeInstanceOf(McpError); + expect(result3.message).toContain('Task cancelled or completed'); + + // Verify queue is cleared + expect(queues.has(taskId)).toBe(false); + }); + + it('should clean up resolver mappings when rejecting requests', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue a request (catch rejection to avoid unhandled promise rejection) + const requestPromise = protocol + .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + // Get the request ID that was sent + const requestResolvers = (protocol as unknown as TestProtocol)._requestResolvers; + const initialResolverCount = requestResolvers.size; + expect(initialResolverCount).toBeGreaterThan(0); + + // Complete the task (triggers cleanup) + const completedTask = { ...task, status: 'completed' as const }; + mockTaskStore.getTask.mockResolvedValue(completedTask); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // Verify request promise is rejected + const result = await requestPromise; + expect(result).toBeInstanceOf(McpError); + expect(result.message).toContain('Task cancelled or completed'); + + // Verify resolver mapping is cleaned up + // The resolver should be removed from the map + expect(requestResolvers.size).toBeLessThan(initialResolverCount); + }); + }); +}); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 3b0d3b71e..27d76f568 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -47,6 +47,77 @@ import { AuthInfo } from '../server/auth/types.js'; import { PendingRequest } from './request.js'; import { isTerminal, TaskStore } from './task.js'; +/** + * Represents a message queued for side-channel delivery via tasks/result. + */ +export interface QueuedMessage { + /** Type of message */ + type: 'request' | 'notification'; + /** The actual JSONRPC message */ + message: JSONRPCRequest | JSONRPCNotification; + /** When it was queued */ + timestamp: number; + /** For requests: resolver to call when response is received */ + responseResolver?: (response: JSONRPCResponse | Error) => void; + /** For requests: the original request ID for response routing */ + originalRequestId?: RequestId; +} + +/** + * A per-task FIFO queue for server-initiated messages that will be delivered + * through the tasks/result response stream. + */ +export class TaskMessageQueue { + private messages: QueuedMessage[] = []; + + /** + * Adds a message to the end of the queue. + * @param message The message to enqueue + */ + enqueue(message: QueuedMessage): void { + this.messages.push(message); + } + + /** + * Removes and returns the first message from the queue. + * @returns The first message, or undefined if the queue is empty + */ + dequeue(): QueuedMessage | undefined { + return this.messages.shift(); + } + + /** + * Removes and returns all messages from the queue. + * @returns Array of all messages that were in the queue + */ + dequeueAll(): QueuedMessage[] { + const allMessages = this.messages; + this.messages = []; + return allMessages; + } + + /** + * Removes all messages from the queue. + */ + clear(): void { + this.messages = []; + } + + /** + * Returns the number of messages in the queue. + */ + size(): number { + return this.messages.length; + } + + /** + * Checks if the queue is empty. + */ + isEmpty(): boolean { + return this.messages.length === 0; + } +} + /** * Callback for progress notifications. */ @@ -81,6 +152,12 @@ export type ProtocolOptions = { * is provided by the server. Defaults to 5000ms if not specified. */ defaultTaskPollInterval?: number; + /** + * Maximum number of messages that can be queued per task for side-channel delivery. + * If undefined, the queue size is unbounded. + * When the limit is exceeded, the task will be transitioned to failed status. + */ + maxTaskQueueSize?: number; }; /** @@ -306,6 +383,11 @@ export abstract class Protocol = new Map(); + private _taskResultWaiters: Map void>> = new Map(); + private _requestResolvers: Map void> = new Map(); + /** * Callback for when the connection is closed for any reason. * @@ -363,62 +445,79 @@ export abstract class Protocol { - // Helper function to wait with abort signal support - const waitWithAbort = (ms: number, signal: AbortSignal): Promise => { - return new Promise((resolve, reject) => { - if (signal.aborted) { - reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled while waiting for task completion')); - return; + const handleTaskResult = async (): Promise => { + const taskId = request.params.taskId; + const queue = this._taskMessageQueues.get(taskId); + + // Deliver queued messages + if (queue && !queue.isEmpty()) { + while (!queue.isEmpty()) { + const queuedMessage = queue.dequeue()!; + + // Send the message on the response stream by passing the relatedRequestId + // This tells the transport to write the message to the tasks/result response stream + await this._transport?.send(queuedMessage.message, { relatedRequestId: extra.requestId }); + + // If it was a request, wait for the response before delivering the next message + if (queuedMessage.type === 'request' && queuedMessage.responseResolver) { + // Wait for response before continuing to next message + await new Promise((resolve, reject) => { + const originalResolver = queuedMessage.responseResolver!; + const wrappedResolver = (response: JSONRPCResponse | Error) => { + // First, deliver the response to the task handler + originalResolver(response); + // Then, signal that we can continue delivering messages + if (response instanceof Error) { + reject(response); + } else { + resolve(); + } + }; + // Replace the resolver so _onresponse calls our wrapped version + if (queuedMessage.originalRequestId !== undefined) { + this._requestResolvers.set(queuedMessage.originalRequestId, wrappedResolver); + } + }); + } } + } + + // Now check task status + const task = await this._taskStore!.getTask(taskId, extra.sessionId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, `Task not found: ${taskId}`); + } - const timeoutId = setTimeout(() => { - signal.removeEventListener('abort', abortHandler); - resolve(); - }, ms); + // Block if task is not terminal and no messages to deliver + if (!isTerminal(task.status) && (!queue || queue.isEmpty())) { + // Wait for status change or new messages + await this._waitForTaskUpdate(taskId, extra.signal); - const abortHandler = () => { - clearTimeout(timeoutId); - reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled while waiting for task completion')); - }; + // After waking up, recursively call to deliver any new messages or result + return await handleTaskResult(); + } - signal.addEventListener('abort', abortHandler, { once: true }); - }); - }; + // If task is terminal, return the result + if (isTerminal(task.status)) { + const result = await this._taskStore!.getTaskResult(taskId, extra.sessionId); - const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); - if (!task) { - throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); - } + this._clearTaskQueue(taskId); - // If task is not in a terminal state, block until it reaches one - if (!isTerminal(task.status)) { - // Poll for task completion - let currentTask = task; - while (!isTerminal(currentTask.status)) { - // Wait for the poll interval before checking again - await waitWithAbort(currentTask.pollInterval ?? 5000, extra.signal); - - // Get updated task status - const updatedTask = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); - if (!updatedTask) { - throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); - } - currentTask = updatedTask; + return { + ...result, + _meta: { + ...result._meta, + [RELATED_TASK_META_KEY]: { + taskId: taskId + } + } + } as SendResultT; } - } - // Task is now in a terminal state (completed, failed, or cancelled) - // Retrieve and return the result - const result = await this._taskStore!.getTaskResult(request.params.taskId, extra.sessionId); - return { - ...result, - _meta: { - ...result._meta, - [RELATED_TASK_META_KEY]: { - taskId: request.params.taskId - } - } - } as SendResultT; + return await handleTaskResult(); + }; + + return await handleTaskResult(); }); this.setRequestHandler(ListTasksRequestSchema, async (request, extra) => { @@ -458,6 +557,9 @@ export abstract class Protocol { - this._cleanupTimeout(messageId); - reject(error); - }); + // Queue request if related to a task + const relatedTaskId = relatedTask?.taskId; + if (relatedTaskId) { + // Store the response resolver for this request so responses can be routed back + const responseResolver = (response: JSONRPCResponse | Error) => { + const handler = this._responseHandlers.get(messageId); + if (handler) { + handler(response); + } else { + // Log error when resolver is missing, but don't fail + this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); + } + }; + this._requestResolvers.set(messageId, responseResolver); + + try { + this._enqueueTaskMessage(relatedTaskId, { + type: 'request', + message: jsonrpcRequest, + timestamp: Date.now(), + responseResolver: responseResolver, + originalRequestId: messageId + }); + + // Notify any waiting tasks/result calls + this._notifyTaskResultWaiters(relatedTaskId); + } catch (error) { + this._cleanupTimeout(messageId); + reject(error); + return; + } + + // Try sending through transport (will be delivered via queue if transport fails) + try { + this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { + this._onerror( + new Error( + `Transport send failed for queued message (this is expected for unidirectional transports): ${error.message}` + ) + ); + }); + } catch (error) { + this._onerror( + new Error( + `Transport send failed synchronously for queued message (this is expected for unidirectional transports): ${error instanceof Error ? error.message : String(error)}` + ) + ); + } + } else { + // No related task - send through transport normally + this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { + this._cleanupTimeout(messageId); + reject(error); + }); + } }); return new PendingRequest( @@ -1003,6 +1170,50 @@ export abstract class Protocol { + this._onerror( + new Error( + `Transport send failed for queued notification (this is expected for unidirectional transports): ${error.message}` + ) + ); + }); + } catch (error) { + this._onerror( + new Error( + `Transport send failed synchronously for queued notification (this is expected for unidirectional transports): ${error instanceof Error ? error.message : String(error)}` + ) + ); + } + return; + } + const debouncedMethods = this._options?.debouncedNotificationMethods ?? []; // A notification can only be debounced if it's in the list AND it's "simple" // (i.e., has no parameters and no related request ID or related task that could be lost). @@ -1150,6 +1361,120 @@ export abstract class Protocol= maxQueueSize) { + const errorMessage = `Task message queue overflow: queue size (${queue.size()}) exceeds maximum (${maxQueueSize})`; + + // Log the error for debugging + this._onerror(new Error(errorMessage)); + + this._taskStore?.updateTaskStatus(taskId, 'failed', 'Task message queue overflow').catch(err => this._onerror(err)); + this._clearTaskQueue(taskId); + + throw new McpError(ErrorCode.InternalError, 'Task message queue overflow'); + } + + queue.enqueue(message); + } + + /** + * Clears the message queue for a task and rejects any pending request resolvers. + * @param taskId The task ID whose queue should be cleared + */ + private _clearTaskQueue(taskId: string): void { + const queue = this._taskMessageQueues.get(taskId); + if (queue) { + // Reject any pending request resolvers + for (const message of queue.dequeueAll()) { + if (message.type === 'request' && message.responseResolver && message.originalRequestId !== undefined) { + message.responseResolver(new McpError(ErrorCode.InternalError, 'Task cancelled or completed')); + // Clean up the resolver mapping + this._requestResolvers.delete(message.originalRequestId); + } + } + this._taskMessageQueues.delete(taskId); + } + } + + /** + * Notifies any waiting tasks/result calls that new messages are available or task status changed. + * @param taskId The task ID to notify waiters for + */ + private _notifyTaskResultWaiters(taskId: string): void { + const waiters = this._taskResultWaiters.get(taskId); + if (waiters) { + for (const waiter of waiters) { + waiter(); + } + this._taskResultWaiters.delete(taskId); + } + } + + /** + * Waits for a task update (new messages or status change) with abort signal support. + * This method uses a hybrid approach: + * 1. Primary: Event-driven notifications via _notifyTaskResultWaiters() when messages + * are queued or task status changes + * 2. Fallback: Lightweight polling (every 100ms) to handle edge cases and race conditions + * + * The polling serves as a safety net for scenarios where notifications might be missed + * due to timing issues, but the event-driven approach handles the majority of cases. + * @param taskId The task ID to wait for + * @param signal Abort signal to cancel the wait + * @returns Promise that resolves when an update occurs or rejects if aborted + */ + private async _waitForTaskUpdate(taskId: string, signal: AbortSignal): Promise { + return new Promise((resolve, reject) => { + if (signal.aborted) { + reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); + return; + } + + const waiters = this._taskResultWaiters.get(taskId) || []; + waiters.push(resolve); + this._taskResultWaiters.set(taskId, waiters); + + signal.addEventListener( + 'abort', + () => { + reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); + }, + { once: true } + ); + + // Polling as a fallback mechanism for edge cases and race conditions + // Most updates will be handled by event-driven notifications via _notifyTaskResultWaiters() + const pollInterval = setInterval(async () => { + try { + const task = await this._taskStore?.getTask(taskId); + if (task && (isTerminal(task.status) || this._taskMessageQueues.get(taskId)?.size())) { + clearInterval(pollInterval); + this._notifyTaskResultWaiters(taskId); + } + } catch { + // Ignore errors during polling + } + }, 100); + + // Clean up the interval when the promise resolves or rejects + const cleanup = () => clearInterval(pollInterval); + signal.addEventListener('abort', cleanup, { once: true }); + }); + } + private requestTaskStore(request?: JSONRPCRequest, sessionId?: string): RequestTaskStore { const taskStore = this._taskStore; if (!taskStore) { @@ -1196,6 +1521,7 @@ export abstract class Protocol Date: Thu, 20 Nov 2025 13:58:55 -0800 Subject: [PATCH 53/84] Replace PendingRequest with async generator --- src/client/index.test.ts | 246 ++++++- src/client/index.ts | 174 ++++- src/examples/client/simpleOAuthClient.ts | 88 ++- src/examples/client/simpleStreamableHttp.ts | 52 +- src/integration-tests/taskLifecycle.test.ts | 9 +- src/server/index.test.ts | 93 ++- src/server/index.ts | 43 ++ src/shared/protocol.test.ts | 722 ++++++++++++++++---- src/shared/protocol.ts | 188 +++-- src/shared/request.test.ts | 253 ------- src/shared/request.ts | 87 --- src/shared/responseMessage.ts | 47 ++ 12 files changed, 1356 insertions(+), 646 deletions(-) delete mode 100644 src/shared/request.test.ts delete mode 100644 src/shared/request.ts create mode 100644 src/shared/responseMessage.ts diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 9b97b8e3b..8eaf0ccc5 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -17,7 +17,8 @@ import { CreateMessageRequestSchema, ElicitRequestSchema, ListRootsRequestSchema, - ErrorCode + ErrorCode, + McpError } from '../types.js'; import { Transport } from '../shared/transport.js'; import { Server } from '../server/index.js'; @@ -719,8 +720,8 @@ test('should handle client cancelling a request', async () => { }); controller.abort('Cancelled by test'); - // Request should be rejected - await expect(listResourcesPromise).rejects.toBe('Cancelled by test'); + // Request should be rejected with an McpError + await expect(listResourcesPromise).rejects.toThrow(McpError); }); /*** @@ -1412,14 +1413,12 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Client creates task on server via tool call - const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); - await pendingRequest.result(); - // Verify task was created successfully by listing tasks const taskList = await client.listTasks(); expect(taskList.tasks.length).toBeGreaterThan(0); @@ -1498,10 +1497,9 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create a task - const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); - await pending.result(); // Query task status by listing tasks and getting the first one const taskList = await client.listTasks(); @@ -1584,10 +1582,9 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create a task - const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); - await pending.result(); // Get the task ID from the task list and query task result const taskList = await client.listTasks(); @@ -1670,10 +1667,9 @@ describe('Task-based execution', () => { const createdTaskIds: string[] = []; for (let i = 0; i < 2; i++) { - const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); - await pending.result(); // Get the task ID from the task list const taskList = await client.listTasks(); @@ -1781,7 +1777,7 @@ describe('Task-based execution', () => { content: z.record(z.unknown()).optional() }); - const pendingRequest = server.beginRequest( + await server.request( { method: 'elicitation/create', params: { @@ -1799,8 +1795,6 @@ describe('Task-based execution', () => { { task: { ttl: 60000 } } ); - await pendingRequest.result(); - // Get the task ID from the task list since it's generated automatically const taskList = await server.listTasks(); expect(taskList.tasks.length).toBeGreaterThan(0); @@ -1884,7 +1878,7 @@ describe('Task-based execution', () => { content: z.record(z.unknown()).optional() }); - const pending = server.beginRequest( + const pending = server.request( { method: 'elicitation/create', params: { @@ -1898,7 +1892,6 @@ describe('Task-based execution', () => { ElicitResultSchema, { task: { ttl: 60000 } } ); - await pending.result(); // Get the task ID from the task list since it's generated automatically const taskList = await server.listTasks(); @@ -1985,7 +1978,7 @@ describe('Task-based execution', () => { content: z.record(z.unknown()).optional() }); - const pending = server.beginRequest( + const pending = server.request( { method: 'elicitation/create', params: { @@ -1999,7 +1992,6 @@ describe('Task-based execution', () => { ElicitResultSchema, { task: { ttl: 60000 } } ); - await pending.result(); // Get the task ID from the task list since it's generated automatically const taskList = await server.listTasks(); @@ -2087,7 +2079,7 @@ describe('Task-based execution', () => { const createdTaskIds: string[] = []; for (let i = 0; i < 2; i++) { - const pending = server.beginRequest( + const pending = server.request( { method: 'elicitation/create', params: { @@ -2101,7 +2093,6 @@ describe('Task-based execution', () => { ElicitResultSchema, { task: { ttl: 60000 } } ); - await pending.result(); // Get the task ID from the task list const taskList = await server.listTasks(); @@ -2213,10 +2204,9 @@ describe('Task-based execution', () => { const createdTaskIds: string[] = []; for (let i = 0; i < 3; i++) { - const pending = client.beginCallTool({ name: 'test-tool', arguments: { id: `task-${i + 1}` } }, CallToolResultSchema, { + await client.callTool({ name: 'test-tool', arguments: { id: `task-${i + 1}` } }, CallToolResultSchema, { task: { ttl: 60000 } }); - await pending.result(); // Get the task ID from the task list const taskList = await client.listTasks(); @@ -2467,10 +2457,11 @@ test('should respect server task capabilities', async () => { }); // These should work because server supports tasks - const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { - task: { ttl: 60000 } - }); - await expect(pendingRequest.result()).resolves.not.toThrow(); + await expect( + client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { ttl: 60000 } + }) + ).resolves.not.toThrow(); await expect(client.listTasks()).resolves.not.toThrow(); // tools/list doesn't support task creation, but it shouldn't throw - it should just ignore the task metadata @@ -2486,3 +2477,204 @@ test('should respect server task capabilities', async () => { serverTaskStore.cleanup(); }); + +/** + * Test: requestStream() method + */ +test('should expose requestStream() method for streaming responses', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + content: [{ type: 'text', text: 'Tool result' }] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // First verify that regular request() works + const regularResult = await client.callTool({ name: 'test-tool', arguments: {} }); + expect(regularResult.content).toEqual([{ type: 'text', text: 'Tool result' }]); + + // Test requestStream with non-task request (should yield only result) + const stream = client.requestStream( + { + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + CallToolResultSchema + ); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + // Should have received only a result message (no task messages) + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('result'); + if (messages[0].type === 'result') { + expect(messages[0].result.content).toEqual([{ type: 'text', text: 'Tool result' }]); + } + + await client.close(); + await server.close(); +}); + +/** + * Test: callToolStream() method + */ +test('should expose callToolStream() method for streaming tool calls', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + content: [{ type: 'text', text: 'Tool result' }] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Test callToolStream + const stream = client.callToolStream({ name: 'test-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + // Should have received messages ending with result + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('result'); + if (messages[0].type === 'result') { + expect(messages[0].result.content).toEqual([{ type: 'text', text: 'Tool result' }]); + } + + await client.close(); + await server.close(); +}); + +/** + * Test: callToolStream() with output schema validation + */ +test('should validate structured output in callToolStream()', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(ListToolsRequestSchema, async () => { + return { + tools: [ + { + name: 'structured-tool', + description: 'A tool with output schema', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + value: { type: 'number' } + }, + required: ['value'] + } + } + ] + }; + }); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + content: [{ type: 'text', text: 'Result' }], + structuredContent: { value: 42 } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the output schema + await client.listTools(); + + // Test callToolStream with valid structured output + const stream = client.callToolStream({ name: 'structured-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + // Should have received result with validated structured content + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('result'); + if (messages[0].type === 'result') { + expect(messages[0].result.structuredContent).toEqual({ value: 42 }); + } + + await client.close(); + await server.close(); +}); diff --git a/src/client/index.ts b/src/client/index.ts index 6107baf38..cbbf41a9f 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,6 +1,6 @@ import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; -import { PendingRequest } from '../shared/request.js'; +import { ResponseMessage } from '../shared/responseMessage.js'; import { type CallToolRequest, @@ -528,29 +528,10 @@ export class Client< return this.request({ method: 'resources/unsubscribe', params }, EmptyResultSchema, options); } - /** - * Begins a tool call and returns a PendingRequest for granular control over task-based execution. - * - * This is useful when you want to create a task for a long-running tool call and poll for results later. - */ - beginCallTool( - params: CallToolRequest['params'], - resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, - options?: RequestOptions - ): PendingRequest { - // Add task creation parameters if server supports it and not explicitly provided - const optionsWithTask = { - ...options, - // We check the server capabilities in auto-assignment, but assume the caller knows what they're doing if they pass this explicitly - task: options?.task ?? (this._serverCapabilities?.tasks?.requests?.tools?.call ? {} : undefined) - }; - return this.beginRequest({ method: 'tools/call', params }, resultSchema, optionsWithTask); - } - /** * Calls a tool and waits for the result. Automatically validates structured output if the tool has an outputSchema. * - * For task-based execution with granular control, use beginCallTool() instead. + * For task-based execution with streaming behavior, use callToolStream() instead. */ async callTool( params: CallToolRequest['params'], @@ -632,4 +613,155 @@ export class Client< async sendRootsListChanged() { return this.notification({ method: 'notifications/roots/list_changed' }); } + + /** + * Sends a request and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to request processing, allowing you to + * observe intermediate task status updates for task-augmented requests. + * + * @example + * ```typescript + * const stream = client.requestStream(request, resultSchema, options); + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Task created:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Task status:', message.task.status); + * break; + * case 'result': + * console.log('Final result:', message.result); + * break; + * case 'error': + * console.error('Error:', message.error); + * break; + * } + * } + * ``` + * + * @param request - The request to send + * @param resultSchema - Zod schema for validating the result + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + */ + requestStream>( + request: ClientRequest | RequestT, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + return super.requestStream(request, resultSchema, options); + } + + /** + * Calls a tool and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to tool execution, allowing you to + * observe intermediate task status updates for long-running tool calls. + * Automatically validates structured output if the tool has an outputSchema. + * + * For simple tool calls without streaming, use callTool() instead. + * + * @example + * ```typescript + * const stream = client.callToolStream({ name: 'myTool', arguments: {} }); + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Tool execution started:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Tool status:', message.task.status); + * break; + * case 'result': + * console.log('Tool result:', message.result); + * // Structured output is automatically validated + * break; + * case 'error': + * console.error('Tool error:', message.error); + * break; + * } + * } + * ``` + * + * @param params - Tool call parameters (name and arguments) + * @param resultSchema - Zod schema for validating the result (defaults to CallToolResultSchema) + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + */ + async *callToolStream( + params: CallToolRequest['params'], + resultSchema: T = CallToolResultSchema as T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + // Add task creation parameters if server supports it and not explicitly provided + const optionsWithTask = { + ...options, + // We check the server capabilities in auto-assignment, but assume the caller knows what they're doing if they pass this explicitly + task: options?.task ?? (this._serverCapabilities?.tasks?.requests?.tools?.call ? {} : undefined) + }; + + const stream = this.requestStream({ method: 'tools/call', params }, resultSchema, optionsWithTask); + + // Get the validator for this tool (if it has an output schema) + const validator = this.getToolOutputValidator(params.name); + + // Iterate through the stream and validate the final result if needed + for await (const message of stream) { + // If this is a result message and the tool has an output schema, validate it + if (message.type === 'result' && validator) { + const result = message.result; + + // If tool has outputSchema, it MUST return structuredContent (unless it's an error) + if (!result.structuredContent && !result.isError) { + yield { + type: 'error', + error: new McpError( + ErrorCode.InvalidRequest, + `Tool ${params.name} has an output schema but did not return structured content` + ) + }; + return; + } + + // Only validate structured content if present (not when there's an error) + if (result.structuredContent) { + try { + // Validate the structured content against the schema + const validationResult = validator(result.structuredContent); + + if (!validationResult.valid) { + yield { + type: 'error', + error: new McpError( + ErrorCode.InvalidParams, + `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` + ) + }; + return; + } + } catch (error) { + if (error instanceof McpError) { + yield { type: 'error', error }; + return; + } + yield { + type: 'error', + error: new McpError( + ErrorCode.InvalidParams, + `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` + ) + }; + return; + } + } + } + + // Yield the message (either validated result or any other message type) + yield message; + } + } } diff --git a/src/examples/client/simpleOAuthClient.ts b/src/examples/client/simpleOAuthClient.ts index fc296bc6a..2299d0fea 100644 --- a/src/examples/client/simpleOAuthClient.ts +++ b/src/examples/client/simpleOAuthClient.ts @@ -254,6 +254,7 @@ class InteractiveOAuthClient { console.log('Commands:'); console.log(' list - List available tools'); console.log(' call [args] - Call a tool'); + console.log(' stream [args] - Call a tool with streaming (shows task status)'); console.log(' quit - Exit the client'); console.log(); @@ -273,8 +274,10 @@ class InteractiveOAuthClient { await this.listTools(); } else if (command.startsWith('call ')) { await this.handleCallTool(command); + } else if (command.startsWith('stream ')) { + await this.handleStreamTool(command); } else { - console.log("❌ Unknown command. Try 'list', 'call ', or 'quit'"); + console.log("❌ Unknown command. Try 'list', 'call ', 'stream ', or 'quit'"); } } catch (error) { if (error instanceof Error && error.message === 'SIGINT') { @@ -375,6 +378,89 @@ class InteractiveOAuthClient { } } + private async handleStreamTool(command: string): Promise { + const parts = command.split(/\s+/); + const toolName = parts[1]; + + if (!toolName) { + console.log('❌ Please specify a tool name'); + return; + } + + // Parse arguments (simple JSON-like format) + let toolArgs: Record = {}; + if (parts.length > 2) { + const argsString = parts.slice(2).join(' '); + try { + toolArgs = JSON.parse(argsString); + } catch { + console.log('❌ Invalid arguments format (expected JSON)'); + return; + } + } + + await this.streamTool(toolName, toolArgs); + } + + private async streamTool(toolName: string, toolArgs: Record): Promise { + if (!this.client) { + console.log('❌ Not connected to server'); + return; + } + + try { + console.log(`\n🔧 Streaming tool '${toolName}'...`); + + const stream = this.client.callToolStream( + { + name: toolName, + arguments: toolArgs + }, + CallToolResultSchema, + { + task: { + taskId: `task-${Date.now()}`, + ttl: 60000 + } + } + ); + + // Iterate through all messages yielded by the generator + for await (const message of stream) { + switch (message.type) { + case 'taskCreated': + console.log(`✓ Task created: ${message.task.taskId}`); + break; + + case 'taskStatus': + console.log(`⟳ Status: ${message.task.status}`); + if (message.task.statusMessage) { + console.log(` ${message.task.statusMessage}`); + } + break; + + case 'result': + console.log('✓ Completed!'); + message.result.content.forEach(content => { + if (content.type === 'text') { + console.log(content.text); + } else { + console.log(content); + } + }); + break; + + case 'error': + console.log('✗ Error:'); + console.log(` ${message.error.message}`); + break; + } + } + } catch (error) { + console.error(`❌ Failed to stream tool '${toolName}':`, error); + } + } + close(): void { this.rl.close(); if (this.client) { diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index c55e105b2..d1e4dfe9e 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -807,13 +807,11 @@ async function callToolTask(name: string, args: Record): Promis console.log('Arguments:', args); // Use task-based execution - call now, fetch later - const taskId = `task-${Date.now()}`; - console.log(`Task ID: ${taskId}`); console.log('This will return immediately while processing continues in the background...'); try { - // Begin the tool call with task metadata - const pendingRequest = client.beginCallTool( + // Call the tool with task metadata using streaming API + const stream = client.callToolStream( { name, arguments: args @@ -821,7 +819,6 @@ async function callToolTask(name: string, args: Record): Promis CallToolResultSchema, { task: { - taskId, ttl: 60000 // Keep results for 60 seconds } } @@ -830,29 +827,30 @@ async function callToolTask(name: string, args: Record): Promis console.log('Waiting for task completion...'); let lastStatus = ''; - await pendingRequest.result({ - onTaskCreated: () => { - console.log('Task created successfully'); - }, - onTaskStatus: task => { - if (lastStatus !== task.status) { - console.log(` ${task.status}${task.statusMessage ? ` - ${task.statusMessage}` : ''}`); - } - lastStatus = task.status; - } - }); - - console.log('Task completed! Fetching result...'); - - // Get the actual result - const result = await client.getTaskResult({ taskId }, CallToolResultSchema); - - console.log('Tool result:'); - result.content.forEach(item => { - if (item.type === 'text') { - console.log(` ${item.text}`); + for await (const message of stream) { + switch (message.type) { + case 'taskCreated': + console.log('Task created successfully'); + break; + case 'taskStatus': + if (lastStatus !== message.task.status) { + console.log(` ${message.task.status}${message.task.statusMessage ? ` - ${message.task.statusMessage}` : ''}`); + } + lastStatus = message.task.status; + break; + case 'result': + console.log('Task completed!'); + console.log('Tool result:'); + message.result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } + }); + break; + case 'error': + throw message.error; } - }); + } } catch (error) { console.log(`Error with task-based execution: ${error}`); } diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts index 220fcbb22..4777cc0fe 100644 --- a/src/integration-tests/taskLifecycle.test.ts +++ b/src/integration-tests/taskLifecycle.test.ts @@ -1056,9 +1056,9 @@ describe('Task Lifecycle Integration Tests', () => { expect(task.status).toBe('cancelled'); // Attempt to call tasks/result - // According to Requirement 4.2: "WHEN a task is cancelled THEN the system SHALL - // clear the message queue and reject any pending message delivery promises" - // This means NO messages should be delivered for a cancelled task + // When a task is cancelled, the system needs to clear the message queue + // and reject any pending message delivery promises, meaning no further + // messages should be delivered for a cancelled task. try { await client.request( { @@ -1072,8 +1072,7 @@ describe('Task Lifecycle Integration Tests', () => { // This is acceptable behavior } - // Verify no elicitation messages were delivered - // This validates Property 12: queue should be cleared immediately on cancellation + // Verify no elicitation messages were delivered, as the queue should be cleared immediately on cancellation expect(elicitationCallCount).toBe(0); // Verify queue remains cleared on subsequent calls diff --git a/src/server/index.test.ts b/src/server/index.test.ts index cc2f81c70..a48f2d639 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -13,6 +13,7 @@ import { ListResourcesRequestSchema, ListToolsRequestSchema, type LoggingMessageNotification, + McpError, NotificationSchema, RequestSchema, ResultSchema, @@ -263,7 +264,7 @@ test('should respect client capabilities', async () => { ).resolves.not.toThrow(); // This should still throw because roots are not supported by the client - await expect(server.listRoots()).rejects.toThrow(/^Client does not support/); + await expect(server.listRoots()).rejects.toThrow(/Client does not support/); }); test('should respect client elicitation capabilities', async () => { @@ -345,7 +346,7 @@ test('should respect client elicitation capabilities', async () => { messages: [], maxTokens: 10 }) - ).rejects.toThrow(/^Client does not support/); + ).rejects.toThrow(/Client does not support/); }); test('should validate elicitation response against requested schema', async () => { @@ -753,8 +754,8 @@ test('should handle server cancelling a request', async () => { ); controller.abort('Cancelled by test'); - // Request should be rejected - await expect(createMessagePromise).rejects.toBe('Cancelled by test'); + // Request should be rejected with an McpError + await expect(createMessagePromise).rejects.toThrow(McpError); }); test('should handle request timeout', async () => { @@ -1043,14 +1044,13 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Use beginCallTool to create a task - const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); // Wait for the task to complete - await pendingRequest.result(); // Get the task ID from the task list since it's generated automatically const taskList = await client.listTasks(); @@ -1274,14 +1274,13 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Call tool WITH task creation - const pendingRequest = client.beginCallTool({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { + await client.callTool({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); // Wait for completion - await pendingRequest.result(); // Verify the nested elicitation request was made (related-task metadata is no longer automatically attached) expect(capturedElicitRequest).toBeDefined(); @@ -1375,7 +1374,7 @@ describe('Task-based execution', () => { content: z.record(z.unknown()).optional() }); - const pendingRequest = server.beginRequest( + await server.request( { method: 'elicitation/create', params: { @@ -1394,8 +1393,6 @@ describe('Task-based execution', () => { { task: { ttl: 60000 } } ); - await pendingRequest.result(); - // Get the task ID from the task list since it's generated automatically const taskList = await server.listTasks(); expect(taskList.tasks.length).toBeGreaterThan(0); @@ -1466,7 +1463,7 @@ describe('Task-based execution', () => { content: z.record(z.unknown()).optional() }); - const pending = server.beginRequest( + await server.request( { method: 'elicitation/create', params: { @@ -1480,7 +1477,6 @@ describe('Task-based execution', () => { ElicitResultSchema, { task: { ttl: 60000 } } ); - await pending.result(); // Get the task ID from the task list since it's generated automatically const taskList = await server.listTasks(); @@ -1554,7 +1550,7 @@ describe('Task-based execution', () => { content: z.record(z.unknown()).optional() }); - const pending = server.beginRequest( + await server.request( { method: 'elicitation/create', params: { @@ -1571,7 +1567,6 @@ describe('Task-based execution', () => { ElicitResultSchema, { task: { ttl: 60000 } } ); - await pending.result(); // Get the task ID from the task list since it's generated automatically const taskList = await server.listTasks(); @@ -1659,7 +1654,7 @@ describe('Task-based execution', () => { const createdTaskIds: string[] = []; for (let i = 0; i < 2; i++) { - const pending = server.beginRequest( + await server.request( { method: 'elicitation/create', params: { @@ -1673,7 +1668,6 @@ describe('Task-based execution', () => { ElicitResultSchema, { task: { ttl: 60000 } } ); - await pending.result(); // Get the task ID from the task list const taskList = await server.listTasks(); @@ -1789,13 +1783,13 @@ describe('Task-based execution', () => { // Create multiple tasks concurrently const pendingRequests = Array.from({ length: 4 }, (_, index) => - client.beginCallTool({ name: 'async-tool', arguments: { delay: 10 + index * 5, taskNum: index + 1 } }, CallToolResultSchema, { + client.callTool({ name: 'async-tool', arguments: { delay: 10 + index * 5, taskNum: index + 1 } }, CallToolResultSchema, { task: { ttl: 60000 } }) ); // Wait for all tasks to complete - await Promise.all(pendingRequests.map(p => p.result())); + await Promise.all(pendingRequests); // Get all task IDs from the task list const taskList = await client.listTasks(); @@ -2026,21 +2020,22 @@ test('should respect client task capabilities', async () => { }); // These should work because client supports tasks - const pendingRequest = server.beginRequest( - { - method: 'elicitation/create', - params: { - message: 'Test', - requestedSchema: { - type: 'object', - properties: { username: { type: 'string' } } + await expect( + server.request( + { + method: 'elicitation/create', + params: { + message: 'Test', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } } - } - }, - ElicitResultSchema, - { task: { ttl: 60000 } } - ); - await expect(pendingRequest.result()).resolves.not.toThrow(); + }, + ElicitResultSchema, + { task: { ttl: 60000 } } + ) + ).resolves.not.toThrow(); await expect(server.listTasks()).resolves.not.toThrow(); // Get the task ID from the task list since it's generated automatically @@ -2051,23 +2046,21 @@ test('should respect client task capabilities', async () => { // This should throw because client doesn't support task creation for sampling/createMessage await expect( - server - .beginRequest( - { - method: 'sampling/createMessage', - params: { - messages: [], - maxTokens: 10 - } - }, - z.object({ - model: z.string(), - role: z.string(), - content: z.any() - }), - { task: { taskId: 'test-task-2', keepAlive: 60000 } } - ) - .result() + server.request( + { + method: 'sampling/createMessage', + params: { + messages: [], + maxTokens: 10 + } + }, + z.object({ + model: z.string(), + role: z.string(), + content: z.any() + }), + { task: { taskId: 'test-task-2', keepAlive: 60000 } } + ) ).rejects.toThrow('Client does not support task creation for sampling/createMessage'); clientTaskStore.cleanup(); diff --git a/src/server/index.ts b/src/server/index.ts index 74d277401..9da4702b6 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,4 +1,5 @@ import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; +import { ResponseMessage } from '../shared/responseMessage.js'; import { type ClientCapabilities, type CreateMessageRequest, @@ -33,6 +34,7 @@ import { } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, jsonSchemaValidator } from '../validation/types.js'; +import { type ZodType, z } from 'zod'; export type ServerOptions = ProtocolOptions & { /** @@ -379,6 +381,47 @@ export class Server< return this._capabilities; } + /** + * Sends a request and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to request processing, allowing you to + * observe intermediate task status updates for task-augmented requests. + * + * @example + * ```typescript + * const stream = server.requestStream(request, resultSchema, options); + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Task created:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Task status:', message.task.status); + * break; + * case 'result': + * console.log('Final result:', message.result); + * break; + * case 'error': + * console.error('Error:', message.error); + * break; + * } + * } + * ``` + * + * @param request - The request to send + * @param resultSchema - Zod schema for validating the result + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + */ + requestStream>( + request: ServerRequest | RequestT, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + return super.requestStream(request, resultSchema, options); + } + async ping() { return this.request({ method: 'ping' }, EmptyResultSchema); } diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 20f5a5e72..ebbd32def 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -3,6 +3,7 @@ import { CallToolRequestSchema, ClientCapabilities, ErrorCode, + JSONRPCMessage, McpError, Notification, RELATED_TASK_META_KEY, @@ -14,10 +15,11 @@ import { TaskCreationParams } from '../types.js'; import { Protocol, mergeCapabilities, TaskMessageQueue } from './protocol.js'; -import { Transport } from './transport.js'; +import { Transport, TransportSendOptions } from './transport.js'; import { TaskStore } from './task.js'; import { MockInstance, vi } from 'vitest'; import { JSONRPCResponse, JSONRPCRequest, JSONRPCError } from '../types.js'; +import { ErrorMessage, ResponseMessage } from './responseMessage.js'; // Type helper for accessing private Protocol properties in tests interface TestProtocol { @@ -40,7 +42,7 @@ class MockTransport implements Transport { async close(): Promise { this.onclose?.(); } - async send(_message: unknown): Promise {} + async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise {} } function createMockTaskStore(options?: { @@ -116,6 +118,21 @@ function createLatch() { }; } +type AsyncGeneratorValue = T extends AsyncGenerator ? U : never; + +async function toArrayAsync>(it: T): Promise[]> { + const arr: AsyncGeneratorValue[] = []; + for await (const o of it) { + arr.push(o as AsyncGeneratorValue); + } + + return arr; +} + +function assertErrorResponse(o: ResponseMessage): asserts o is ErrorMessage { + expect(o.type).toBe('error'); +} + describe('protocol tests', () => { let protocol: Protocol; let transport: MockTransport; @@ -193,9 +210,14 @@ describe('protocol tests', () => { }); const onProgressMock = vi.fn(); - protocol.request(request, mockSchema, { - onprogress: onProgressMock - }); + // Start request but don't await - we're testing the sent message + void protocol + .request(request, mockSchema, { + onprogress: onProgressMock + }) + .catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -228,9 +250,14 @@ describe('protocol tests', () => { }); const onProgressMock = vi.fn(); - protocol.request(request, mockSchema, { - onprogress: onProgressMock - }); + // Start request but don't await - we're testing the sent message + void protocol + .request(request, mockSchema, { + onprogress: onProgressMock + }) + .catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -263,7 +290,10 @@ describe('protocol tests', () => { result: z.string() }); - protocol.request(request, mockSchema); + // Start request but don't await - we're testing the sent message + void protocol.request(request, mockSchema).catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -291,9 +321,14 @@ describe('protocol tests', () => { }); const onProgressMock = vi.fn(); - protocol.request(request, mockSchema, { - onprogress: onProgressMock - }); + // Start request but don't await - we're testing the sent message + void protocol + .request(request, mockSchema, { + onprogress: onProgressMock + }) + .catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -1056,7 +1091,7 @@ describe('Task-based execution', () => { })(); }); - describe('beginRequest with task metadata', () => { + describe('request with task metadata', () => { it('should include task parameters at top level', async () => { await protocol.connect(transport); @@ -1069,12 +1104,16 @@ describe('Task-based execution', () => { content: z.array(z.object({ type: z.literal('text'), text: z.string() })) }); - protocol.beginRequest(request, resultSchema, { - task: { - ttl: 30000, - pollInterval: 1000 - } - }); + void protocol + .request(request, resultSchema, { + task: { + ttl: 30000, + pollInterval: 1000 + } + }) + .catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -1108,11 +1147,15 @@ describe('Task-based execution', () => { content: z.array(z.object({ type: z.literal('text'), text: z.string() })) }); - protocol.beginRequest(request, resultSchema, { - task: { - ttl: 60000 - } - }); + void protocol + .request(request, resultSchema, { + task: { + ttl: 60000 + } + }) + .catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -1130,7 +1173,7 @@ describe('Task-based execution', () => { ); }); - it('should return PendingRequest object', async () => { + it('should return Promise for task-augmented request', async () => { await protocol.connect(transport); const request = { @@ -1142,14 +1185,14 @@ describe('Task-based execution', () => { content: z.array(z.object({ type: z.literal('text'), text: z.string() })) }); - const pendingRequest = protocol.beginRequest(request, resultSchema, { + const resultPromise = protocol.request(request, resultSchema, { task: { ttl: 30000 } }); - expect(pendingRequest).toBeDefined(); - expect(pendingRequest.taskId).toBeUndefined(); // taskId is generated by receiver, not provided by client + expect(resultPromise).toBeDefined(); + expect(resultPromise).toBeInstanceOf(Promise); }); }); @@ -1164,25 +1207,28 @@ describe('Task-based execution', () => { const resultSchema = z.object({}); - protocol.beginRequest(request, resultSchema, { - relatedTask: { - taskId: 'parent-task-123' - } - }); - - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - params: { - data: 'test', - _meta: { - [RELATED_TASK_META_KEY]: { - taskId: 'parent-task-123' - } - } + // Start the request (don't await completion, just let it send) + void protocol + .request(request, resultSchema, { + relatedTask: { + taskId: 'parent-task-123' } - }), - expect.any(Object) - ); + }) + .catch(() => { + // May not complete, ignore error + }); + + // Wait a bit for the request to be queued + await new Promise(resolve => setTimeout(resolve, 10)); + + // Requests with relatedTask should be queued, not sent via transport + // This prevents duplicate delivery for bidirectional transports + expect(sendSpy).not.toHaveBeenCalled(); + + // Verify the message was queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueues.get('parent-task-123'); + expect(queue).toBeDefined(); + expect(queue!.size()).toBeGreaterThan(0); }); it('should work with notification method', async () => { @@ -1200,21 +1246,20 @@ describe('Task-based execution', () => { } ); - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - method: 'notifications/message', - params: { - level: 'info', - data: 'test message', - _meta: { - [RELATED_TASK_META_KEY]: { - taskId: 'parent-task-456' - } - } - } - }), - expect.any(Object) - ); + // Notifications with relatedTask should be queued, not sent via transport + // This prevents duplicate delivery for bidirectional transports + expect(sendSpy).not.toHaveBeenCalled(); + + // Verify the message was queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueues.get('parent-task-456'); + expect(queue).toBeDefined(); + expect(queue!.size()).toBe(1); + + const queuedMessage = queue!.dequeue(); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('notification'); + expect(queuedMessage?.message.method).toBe('notifications/message'); + expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: 'parent-task-456' }); }); }); @@ -1231,35 +1276,50 @@ describe('Task-based execution', () => { content: z.array(z.object({ type: z.literal('text'), text: z.string() })) }); - protocol.beginRequest(request, resultSchema, { + // Start the request (don't await completion, just let it send) + void protocol + .request(request, resultSchema, { + task: { + ttl: 60000, + pollInterval: 1000 + }, + relatedTask: { + taskId: 'parent-task' + }, + onprogress: vi.fn() + }) + .catch(() => { + // May not complete, ignore error + }); + + // Wait a bit for the request to be queued + await new Promise(resolve => setTimeout(resolve, 10)); + + // Requests with relatedTask should be queued, not sent via transport + // This prevents duplicate delivery for bidirectional transports + expect(sendSpy).not.toHaveBeenCalled(); + + // Verify the message was queued with all metadata combined + const queue = (protocol as unknown as TestProtocol)._taskMessageQueues.get('parent-task'); + expect(queue).toBeDefined(); + expect(queue!.size()).toBeGreaterThan(0); + + const queuedMessage = queue!.dequeue(); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('request'); + expect(queuedMessage?.message.params).toMatchObject({ + name: 'test-tool', task: { ttl: 60000, pollInterval: 1000 }, - relatedTask: { - taskId: 'parent-task' - }, - onprogress: vi.fn() + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: 'parent-task' + }, + progressToken: expect.any(Number) + } }); - - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - params: { - name: 'test-tool', - task: { - ttl: 60000, - pollInterval: 1000 - }, - _meta: { - [RELATED_TASK_META_KEY]: { - taskId: 'parent-task' - }, - progressToken: expect.any(Number) - } - } - }), - expect.any(Object) - ); }); }); @@ -1996,7 +2056,15 @@ describe('Task-based execution', () => { }); // Send a request with related-task metadata - serverTransport.onmessage?.({ + let handlerPromise: Promise | undefined; + const originalOnMessage = serverTransport.onmessage; + + serverTransport.onmessage = message => { + handlerPromise = Promise.resolve(originalOnMessage?.(message)); + return handlerPromise; + }; + + serverTransport.onmessage({ jsonrpc: '2.0', id: 1, method: 'tools/call', @@ -2010,23 +2078,32 @@ describe('Task-based execution', () => { } }); - // Wait for async processing - await new Promise(resolve => setTimeout(resolve, 50)); + // Wait for handler to complete + if (handlerPromise) { + await handlerPromise; + } + await new Promise(resolve => setTimeout(resolve, 100)); - // Verify the notification includes related-task metadata - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - method: 'notifications/message', - params: expect.objectContaining({ - _meta: expect.objectContaining({ - [RELATED_TASK_META_KEY]: { - taskId: 'parent-task-123' - } - }) - }) - }), - expect.any(Object) - ); + // Verify the notification was QUEUED (not sent via transport) + // Messages with relatedTask metadata should be queued for delivery via tasks/result + // to prevent duplicate delivery for bidirectional transports + const queues = (serverProtocol as unknown as TestProtocol)._taskMessageQueues; + expect(queues.has('parent-task-123')).toBe(true); + + const queue = queues.get('parent-task-123')!; + expect(queue.size()).toBeGreaterThan(0); + + const queuedMessage = queue.dequeue(); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('notification'); + expect(queuedMessage?.message.method).toBe('notifications/message'); + expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ + taskId: 'parent-task-123' + }); + + // Verify the notification was NOT sent via transport (should be queued instead) + const notificationCalls = sendSpy.mock.calls.filter(call => 'method' in call[0] && call[0].method === 'notifications/message'); + expect(notificationCalls).toHaveLength(0); }); }); }); @@ -2371,10 +2448,17 @@ describe('Progress notification support for tasks', () => { }); // Start a task-augmented request with progress callback - protocol.beginRequest(request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }); + void protocol + .request(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }) + .catch(() => { + // May not complete, ignore error + }); + + // Wait a bit for the request to be sent + await new Promise(resolve => setTimeout(resolve, 10)); // Get the message ID from the sent request const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; @@ -2474,10 +2558,17 @@ describe('Progress notification support for tasks', () => { }); // Start a task-augmented request with progress callback - protocol.beginRequest(request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }); + void protocol + .request(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }) + .catch(() => { + // May not complete, ignore error + }); + + // Wait a bit for the request to be sent + await new Promise(resolve => setTimeout(resolve, 10)); const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; const messageId = sentRequest.id; @@ -2584,7 +2675,7 @@ describe('Progress notification support for tasks', () => { }) }); - protocol.beginRequest(request, resultSchema, { + void protocol.request(request, resultSchema, { task: { ttl: 60000 }, onprogress: progressCallback }); @@ -2683,7 +2774,7 @@ describe('Progress notification support for tasks', () => { }) }); - protocol.beginRequest(request, resultSchema, { + void protocol.request(request, resultSchema, { task: { ttl: 60000 }, onprogress: progressCallback }); @@ -2779,7 +2870,7 @@ describe('Progress notification support for tasks', () => { }) }); - protocol.beginRequest(request, resultSchema, { + void protocol.request(request, resultSchema, { task: { ttl: 60000 }, onprogress: progressCallback }); @@ -2850,7 +2941,7 @@ describe('Progress notification support for tasks', () => { const onProgressMock = vi.fn(); - protocol.beginRequest(request, resultSchema, { + void protocol.request(request, resultSchema, { task: { ttl: 60000 }, @@ -2875,7 +2966,7 @@ describe('Progress notification support for tasks', () => { const onProgressMock = vi.fn(); - protocol.beginRequest(request, resultSchema, { + void protocol.request(request, resultSchema, { task: { ttl: 30000 }, @@ -2925,7 +3016,7 @@ describe('Progress notification support for tasks', () => { const onProgressMock = vi.fn(); - protocol.beginRequest(request, resultSchema, { + void protocol.request(request, resultSchema, { task: { ttl: 30000 }, @@ -4262,3 +4353,402 @@ describe('Queue lifecycle management', () => { }); }); }); + +describe('requestStream() method', () => { + const CallToolResultSchema = z.object({ + content: z.array(z.object({ type: z.string(), text: z.string() })), + _meta: z.object({}).optional() + }); + + test('should yield result immediately for non-task requests', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + // Start the request stream + const streamPromise = (async () => { + const messages = []; + const stream = protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema); + for await (const message of stream) { + messages.push(message); + } + return messages; + })(); + + // Simulate server response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: { + content: [{ type: 'text', text: 'test result' }], + _meta: {} + } + }); + + const messages = await streamPromise; + + // Should yield exactly one result message + expect(messages).toHaveLength(1); + expect(messages[0].type).toBe('result'); + expect(messages[0]).toHaveProperty('result'); + }); + + test('should yield error message on request failure', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + // Start the request stream + const streamPromise = (async () => { + const messages = []; + const stream = protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema); + for await (const message of stream) { + messages.push(message); + } + return messages; + })(); + + // Simulate server error response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + error: { + code: ErrorCode.InternalError, + message: 'Test error' + } + }); + + const messages = await streamPromise; + + // Should yield exactly one error message + expect(messages).toHaveLength(1); + expect(messages[0].type).toBe('error'); + expect(messages[0]).toHaveProperty('error'); + if (messages[0].type === 'error') { + expect(messages[0].error.message).toContain('Test error'); + } + }); + + test('should handle cancellation via AbortSignal', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const abortController = new AbortController(); + + // Abort immediately before starting the stream + abortController.abort('User cancelled'); + + // Start the request stream with already-aborted signal + const messages = []; + const stream = protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { + signal: abortController.signal + }); + for await (const message of stream) { + messages.push(message); + } + + // Should yield error message about cancellation + expect(messages).toHaveLength(1); + expect(messages[0].type).toBe('error'); + if (messages[0].type === 'error') { + expect(messages[0].error.message).toContain('cancelled'); + } + }); + + describe('Error responses', () => { + test('should yield error as terminal message for server error response', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Simulate server error response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + error: { + code: ErrorCode.InternalError, + message: 'Server error' + } + }); + + // Collect messages + const messages = await messagesPromise; + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + expect(lastMessage.error.message).toContain('Server error'); + }); + + test('should yield error as terminal message for timeout', async () => { + vi.useFakeTimers(); + try { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { + timeout: 100 + }) + ); + + // Advance time to trigger timeout + await vi.advanceTimersByTimeAsync(101); + + // Collect messages + const messages = await messagesPromise; + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + expect(lastMessage.error.code).toBe(ErrorCode.RequestTimeout); + } finally { + vi.useRealTimers(); + } + }); + + test('should yield error as terminal message for cancellation', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const abortController = new AbortController(); + abortController.abort('User cancelled'); + + // Collect messages + const messages = await toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { + signal: abortController.signal + }) + ); + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + expect(lastMessage.error.message).toContain('cancelled'); + }); + + test('should not yield any messages after error message', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Simulate server error response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + error: { + code: ErrorCode.InternalError, + message: 'Test error' + } + }); + + // Collect messages + const messages = await messagesPromise; + + // Verify only one message (the error) was yielded + expect(messages).toHaveLength(1); + expect(messages[0].type).toBe('error'); + + // Try to send another message (should be ignored) + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: { + content: [{ type: 'text', text: 'should not appear' }] + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify no additional messages were yielded + expect(messages).toHaveLength(1); + }); + + test('should yield error as terminal message for task failure', async () => { + const transport = new MockTransport(); + const mockTaskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Simulate task creation response + await new Promise(resolve => setTimeout(resolve, 10)); + const taskId = 'test-task-123'; + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: { + _meta: { + task: { + taskId, + status: 'working', + createdAt: new Date().toISOString(), + pollInterval: 100 + } + } + } + }); + + // Wait for task creation to be processed + await new Promise(resolve => setTimeout(resolve, 20)); + + // Update task to failed status + const failedTask = { + taskId, + status: 'failed' as const, + createdAt: new Date().toISOString(), + pollInterval: 100, + ttl: null, + statusMessage: 'Task failed' + }; + mockTaskStore.getTask.mockResolvedValue(failedTask); + + // Collect messages + const messages = await messagesPromise; + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + }); + + test('should yield error as terminal message for network error', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + // Override send to simulate network error + transport.send = vi.fn().mockRejectedValue(new Error('Network error')); + + const messages = await toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + }); + + test('should ensure error is always the final message', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Simulate server error response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + error: { + code: ErrorCode.InternalError, + message: 'Test error' + } + }); + + // Collect messages + const messages = await messagesPromise; + + // Verify error is the last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + expect(lastMessage.type).toBe('error'); + + // Verify all messages before the last are not terminal + for (let i = 0; i < messages.length - 1; i++) { + expect(messages[i].type).not.toBe('error'); + expect(messages[i].type).not.toBe('result'); + } + }); + }); +}); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 27d76f568..2cda0ff08 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -2,6 +2,7 @@ import { ZodLiteral, ZodObject, ZodType, z } from 'zod'; import { CancelledNotificationSchema, ClientCapabilities, + CreateTaskResultSchema, ErrorCode, GetTaskRequest, GetTaskRequestSchema, @@ -44,8 +45,8 @@ import { } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; -import { PendingRequest } from './request.js'; import { isTerminal, TaskStore } from './task.js'; +import { ResponseMessage } from './responseMessage.js'; /** * Represents a message queued for side-channel delivery via tasks/result. @@ -454,9 +455,21 @@ export abstract class Protocol>( + async *requestStream>( request: SendRequestT, resultSchema: T, options?: RequestOptions - ): PendingRequest> { + ): AsyncGenerator>, void, void> { + const { task } = options ?? {}; + + // For non-task requests, just yield the result + if (!task) { + try { + const result = await this.request(request, resultSchema, options); + yield { type: 'result', result }; + } catch (error) { + yield { + type: 'error', + error: error instanceof McpError ? error : new McpError(ErrorCode.InternalError, String(error)) + }; + } + return; + } + + // For task-augmented requests, we need to poll for status + // First, make the request to create the task + let taskId: string | undefined; + try { + // Send the request and get the CreateTaskResult + const createResult = await this.request(request, CreateTaskResultSchema as unknown as T, options); + + // Extract taskId from the result + if ('task' in createResult && typeof createResult.task === 'object' && createResult.task && 'taskId' in createResult.task) { + taskId = (createResult.task as { taskId: string }).taskId; + yield { type: 'taskCreated', task: createResult.task as Task }; + } else { + throw new McpError(ErrorCode.InternalError, 'Task creation did not return a task'); + } + + // Poll for task completion + while (true) { + // Get current task status + const task = await this.getTask({ taskId }, options); + yield { type: 'taskStatus', task }; + + // Check if task is terminal + if (isTerminal(task.status)) { + if (task.status === 'completed') { + // Get the final result + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + } else if (task.status === 'failed') { + yield { + type: 'error', + error: new McpError(ErrorCode.InternalError, `Task ${taskId} failed`) + }; + } else if (task.status === 'cancelled') { + yield { + type: 'error', + error: new McpError(ErrorCode.InternalError, `Task ${taskId} was cancelled`) + }; + } + return; + } + + // Wait before polling again + const pollInterval = task.pollInterval ?? 1000; + await new Promise(resolve => setTimeout(resolve, pollInterval)); + + // Check if cancelled + options?.signal?.throwIfAborted(); + } + } catch (error) { + yield { + type: 'error', + error: error instanceof McpError ? error : new McpError(ErrorCode.InternalError, String(error)) + }; + } + } + + /** + * Sends a request and waits for a response. + * + * Do not use this method to emit notifications! Use notification() instead. + */ + request>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; // Send the request - const result = new Promise>((resolve, reject) => { + return new Promise>((resolve, reject) => { const earlyReject = (error: unknown) => { reject(error); }; @@ -1020,7 +1131,9 @@ export abstract class Protocol this._onerror(new Error(`Failed to send cancellation: ${error}`))); - reject(reason); + // Wrap the reason in an McpError if it isn't already + const error = reason instanceof McpError ? reason : new McpError(ErrorCode.RequestTimeout, String(reason)); + reject(error); }; this._responseHandlers.set(messageId, response => { @@ -1081,22 +1194,8 @@ export abstract class Protocol { - this._onerror( - new Error( - `Transport send failed for queued message (this is expected for unidirectional transports): ${error.message}` - ) - ); - }); - } catch (error) { - this._onerror( - new Error( - `Transport send failed synchronously for queued message (this is expected for unidirectional transports): ${error instanceof Error ? error.message : String(error)}` - ) - ); - } + // Don't send through transport - queued messages are delivered via tasks/result only + // This prevents duplicate delivery for bidirectional transports } else { // No related task - send through transport normally this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { @@ -1105,23 +1204,6 @@ export abstract class Protocol>, - result, - resultSchema, - undefined, - this._options?.defaultTaskPollInterval - ); - } - - /** - * Sends a request and waits for a response. - * - * Do not use this method to emit notifications! Use notification() instead. - */ - request>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { - return this.beginRequest(request, resultSchema, options).result(); } /** @@ -1195,22 +1277,8 @@ export abstract class Protocol { - this._onerror( - new Error( - `Transport send failed for queued notification (this is expected for unidirectional transports): ${error.message}` - ) - ); - }); - } catch (error) { - this._onerror( - new Error( - `Transport send failed synchronously for queued notification (this is expected for unidirectional transports): ${error instanceof Error ? error.message : String(error)}` - ) - ); - } + // Don't send through transport - queued messages are delivered via tasks/result only + // This prevents duplicate delivery for bidirectional transports return; } @@ -1521,7 +1589,8 @@ export abstract class Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - - // Expose methods for testing - public mockGetTask = vi.fn(); - public mockGetTaskResult = vi.fn(); - - async getTask(params: { taskId: string }): Promise { - return this.mockGetTask(params); - } - - async getTaskResult(params: { taskId: string }, _resultSchema: ZodType): Promise { - return this.mockGetTaskResult(params, _resultSchema) as Promise; - } -} - -describe('PendingRequest', () => { - let protocol: MockProtocol; - const mockResultSchema = z.object({ result: z.string() }); - - beforeEach(() => { - protocol = new MockProtocol(); - }); - - describe('input_required status handling', () => { - it('should preemptively call tasks/result when input_required status is encountered', async () => { - // Setup: Create a task that transitions to input_required - const taskId = 'test-task-123'; - const expectedResult = { result: 'completed after input' }; - - // Mock getTask to return input_required status - protocol.mockGetTask.mockResolvedValueOnce({ - taskId, - status: 'input_required', - ttl: null, - createdAt: new Date().toISOString(), - pollInterval: 1000 - }); - - // Mock getTaskResult to return the final result - protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); - - // Create a PendingRequest with a task ID - const resultHandle = Promise.resolve(expectedResult); - const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); - - // Execute: Call result() which should trigger taskHandler - const result = await pendingRequest.result({ - onTaskCreated: vi.fn(), - onTaskStatus: vi.fn() - }); - - // Verify: getTask was called once - expect(protocol.mockGetTask).toHaveBeenCalledTimes(1); - expect(protocol.mockGetTask).toHaveBeenCalledWith({ taskId }); - - // Verify: getTaskResult was called immediately after detecting input_required - expect(protocol.mockGetTaskResult).toHaveBeenCalledTimes(1); - expect(protocol.mockGetTaskResult).toHaveBeenCalledWith({ taskId }, mockResultSchema); - - // Verify: Result is correct - expect(result).toEqual(expectedResult); - }); - - it('should call onTaskStatus before calling tasks/result for input_required', async () => { - const taskId = 'test-task-456'; - const expectedResult = { result: 'completed' }; - const onTaskStatus = vi.fn(); - - const inputRequiredTask: GetTaskResult = { - taskId, - status: 'input_required', - ttl: null, - createdAt: new Date().toISOString(), - pollInterval: 1000 - }; - - protocol.mockGetTask.mockResolvedValueOnce(inputRequiredTask); - protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); - - const resultHandle = Promise.resolve(expectedResult); - const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); - - await pendingRequest.result({ - onTaskCreated: vi.fn(), - onTaskStatus - }); - - // Verify: onTaskStatus was called with the input_required task - expect(onTaskStatus).toHaveBeenCalledWith(inputRequiredTask); - expect(onTaskStatus).toHaveBeenCalledBefore(protocol.mockGetTaskResult); - }); - - it('should not poll again after encountering input_required status', async () => { - const taskId = 'test-task-789'; - const expectedResult = { result: 'completed' }; - - protocol.mockGetTask.mockResolvedValueOnce({ - taskId, - status: 'input_required', - ttl: null, - createdAt: new Date().toISOString(), - pollInterval: 100 // Short interval to test that we don't wait - }); - - protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); - - const resultHandle = Promise.resolve(expectedResult); - const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); - - const startTime = Date.now(); - await pendingRequest.result({ - onTaskCreated: vi.fn(), - onTaskStatus: vi.fn() - }); - const endTime = Date.now(); - - // Verify: getTask was only called once (no polling) - expect(protocol.mockGetTask).toHaveBeenCalledTimes(1); - - // Verify: The operation completed quickly without waiting for pollInterval - expect(endTime - startTime).toBeLessThan(100); - }); - - it('should continue normal polling for working status before input_required', async () => { - const taskId = 'test-task-abc'; - const expectedResult = { result: 'completed' }; - - // First poll: working status - protocol.mockGetTask.mockResolvedValueOnce({ - taskId, - status: 'working', - ttl: null, - createdAt: new Date().toISOString(), - pollInterval: 10 - }); - - // Second poll: input_required status - protocol.mockGetTask.mockResolvedValueOnce({ - taskId, - status: 'input_required', - ttl: null, - createdAt: new Date().toISOString(), - pollInterval: 10 - }); - - protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); - - const resultHandle = Promise.resolve(expectedResult); - const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); - - await pendingRequest.result({ - onTaskCreated: vi.fn(), - onTaskStatus: vi.fn() - }); - - // Verify: getTask was called twice (once for working, once for input_required) - expect(protocol.mockGetTask).toHaveBeenCalledTimes(2); - - // Verify: getTaskResult was called after input_required was detected - expect(protocol.mockGetTaskResult).toHaveBeenCalledTimes(1); - }); - - it('should handle terminal status normally without input_required', async () => { - const taskId = 'test-task-def'; - const expectedResult = { result: 'completed' }; - - // Task is already completed - protocol.mockGetTask.mockResolvedValueOnce({ - taskId, - status: 'completed', - ttl: null, - createdAt: new Date().toISOString(), - pollInterval: 1000 - }); - - protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); - - const resultHandle = Promise.resolve(expectedResult); - const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); - - await pendingRequest.result({ - onTaskCreated: vi.fn(), - onTaskStatus: vi.fn() - }); - - // Verify: Normal flow - getTask once, then getTaskResult - expect(protocol.mockGetTask).toHaveBeenCalledTimes(1); - expect(protocol.mockGetTaskResult).toHaveBeenCalledTimes(1); - }); - }); - - describe('normal task polling', () => { - it('should poll until terminal status is reached', async () => { - const taskId = 'test-task-polling'; - const expectedResult = { result: 'completed' }; - - // First poll: working - protocol.mockGetTask.mockResolvedValueOnce({ - taskId, - status: 'working', - ttl: null, - createdAt: new Date().toISOString(), - pollInterval: 10 - }); - - // Second poll: still working - protocol.mockGetTask.mockResolvedValueOnce({ - taskId, - status: 'working', - ttl: null, - createdAt: new Date().toISOString(), - pollInterval: 10 - }); - - // Third poll: completed - protocol.mockGetTask.mockResolvedValueOnce({ - taskId, - status: 'completed', - ttl: null, - createdAt: new Date().toISOString(), - pollInterval: 10 - }); - - protocol.mockGetTaskResult.mockResolvedValueOnce(expectedResult); - - const resultHandle = Promise.resolve(expectedResult); - const pendingRequest = new PendingRequest(protocol, resultHandle, mockResultSchema, taskId, 5000); - - await pendingRequest.result({ - onTaskCreated: vi.fn(), - onTaskStatus: vi.fn() - }); - - // Verify: getTask was called three times - expect(protocol.mockGetTask).toHaveBeenCalledTimes(3); - - // Verify: getTaskResult was called once after terminal status - expect(protocol.mockGetTaskResult).toHaveBeenCalledTimes(1); - }); - }); -}); diff --git a/src/shared/request.ts b/src/shared/request.ts deleted file mode 100644 index 7333b57d6..000000000 --- a/src/shared/request.ts +++ /dev/null @@ -1,87 +0,0 @@ -import { ZodType } from 'zod'; -import { Protocol } from './protocol.js'; -import { Request, Notification, Result, Task, GetTaskResult } from '../types.js'; -import { isTerminal } from './task.js'; - -const DEFAULT_TASK_POLLING_INTERVAL = 5000; - -const DEFAULT_HANDLER = () => Promise.resolve(); - -export interface TaskHandlerOptions { - onTaskCreated: () => Promise | void; - onTaskStatus: (task: Task) => Promise | void; -} - -export class PendingRequest { - constructor( - readonly protocol: Protocol, - readonly resultHandle: Promise, - readonly resultSchema: ZodType, - readonly taskId?: string, - readonly defaultTaskPollInterval?: number - ) {} - - /** - * Waits for a result, calling onTaskStatus if provided and a task was created. - */ - async result(options?: Partial): Promise { - const { onTaskCreated = DEFAULT_HANDLER, onTaskStatus = DEFAULT_HANDLER } = options ?? {}; - - if (!this.taskId) { - // No task ID provided, just block for the result - return await this.resultHandle; - } - - // For task-based requests, start task polling and race with direct result - return Promise.allSettled([ - (async () => { - // Call onTaskCreated immediately since task is created synchronously by tool implementor - await onTaskCreated(); - - // Start task polling - return await this.taskHandler(this.taskId!, { - onTaskCreated, - onTaskStatus - }); - })(), - this.resultHandle - ]).then(([task, result]) => { - if (task.status === 'fulfilled') { - return task.value; - } else if (result.status === 'fulfilled') { - return result.value; - } - - // Both failed - prefer to throw the result error since it's usually more meaningful - // (e.g., timeout, connection error, etc.) than the task polling failure - throw result.reason; - }); - } - - /** - * Encapsulates polling for a result, calling onTaskStatus after querying the task. - */ - private async taskHandler(taskId: string, { onTaskStatus }: TaskHandlerOptions): Promise { - // Poll for completion - let task: GetTaskResult; - do { - task = await this.protocol.getTask({ taskId: taskId }); - await onTaskStatus(task); - - // Handle input_required status: preemptively call tasks/result instead of continuing to poll - // This allows the receiver to block and wait for user input before returning the result - if (task.status === 'input_required') { - return await this.protocol.getTaskResult({ taskId: taskId }, this.resultSchema); - } - - if (!isTerminal(task.status)) { - await new Promise(resolve => - setTimeout(resolve, task.pollInterval ?? this.defaultTaskPollInterval ?? DEFAULT_TASK_POLLING_INTERVAL) - ); - } - } while (!isTerminal(task.status)); - - // Process result - return await this.protocol.getTaskResult({ taskId: taskId }, this.resultSchema); - } -} diff --git a/src/shared/responseMessage.ts b/src/shared/responseMessage.ts new file mode 100644 index 000000000..c7584c0cd --- /dev/null +++ b/src/shared/responseMessage.ts @@ -0,0 +1,47 @@ +import { Result, Task, McpError } from '../types.js'; + +/** + * Base message type + */ +export interface BaseResponseMessage { + type: string; +} + +/** + * Task status update message + */ +export interface TaskStatusMessage extends BaseResponseMessage { + type: 'taskStatus'; + task: Task; +} + +/** + * Task created message (first message for task-augmented requests) + */ +export interface TaskCreatedMessage extends BaseResponseMessage { + type: 'taskCreated'; + task: Task; +} + +/** + * Final result message (terminal) + */ +export interface ResultMessage extends BaseResponseMessage { + type: 'result'; + result: T; +} + +/** + * Error message (terminal) + */ +export interface ErrorMessage extends BaseResponseMessage { + type: 'error'; + error: McpError; +} + +/** + * Union type representing all possible messages that can be yielded during request processing. + * Note: Progress notifications are handled through the existing onprogress callback mechanism. + * Side-channeled messages (server requests/notifications) are handled through registered handlers. + */ +export type ResponseMessage = TaskStatusMessage | TaskCreatedMessage | ResultMessage | ErrorMessage; From 57ab83b0f223e7f6f19fbf33c1ea2f4361da3ce2 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 20 Nov 2025 14:30:47 -0800 Subject: [PATCH 54/84] Remove unneeded casts in requestStream() --- src/shared/protocol.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 2cda0ff08..d1c6a4473 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -989,12 +989,12 @@ export abstract class Protocol setTimeout(resolve, pollInterval)); // Check if cancelled From a279a45a580e50a38bf0231c8e2adda6d6ea1df3 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 20 Nov 2025 15:02:56 -0800 Subject: [PATCH 55/84] Reuse callToolStream inside of callTool --- src/client/index.test.ts | 404 +++++++++++++++++----------------- src/client/index.ts | 222 +++++++++---------- src/server/index.test.ts | 290 ++++++++++++------------ src/shared/protocol.test.ts | 13 +- src/shared/responseMessage.ts | 23 ++ 5 files changed, 476 insertions(+), 476 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 8eaf0ccc5..888153d5d 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -22,6 +22,7 @@ import { } from '../types.js'; import { Transport } from '../shared/transport.js'; import { Server } from '../server/index.js'; +import { McpServer } from '../server/mcp.js'; import { InMemoryTransport } from '../inMemory.js'; import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; @@ -1343,14 +1344,13 @@ describe('Task-based execution', () => { }); test('should create task on server via tool call', async () => { - const server = new Server( + const server = new McpServer( { name: 'test-server', version: '1.0.0' }, { capabilities: { - tools: {}, tasks: { requests: { tools: { @@ -1363,45 +1363,42 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - let taskId: string | undefined; - - // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; - } + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } + ); - if (request.params.name === 'test-tool') { - const result = { - content: [{ type: 'text', text: 'Tool executed successfully!' }] - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); - } - return result; - } - throw new Error('Unknown tool'); - }); + const result = { + content: [{ type: 'text', text: 'Tool executed successfully!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; } - ] - })); + } + ); const client = new Client({ name: 'test-client', @@ -1427,14 +1424,13 @@ describe('Task-based execution', () => { }); test('should query task status from server using getTask', async () => { - const server = new Server( + const server = new McpServer( { name: 'test-server', version: '1.0.0' }, { capabilities: { - tools: {}, tasks: { requests: { tools: { @@ -1447,45 +1443,42 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - let taskId: string | undefined; - - // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; - } + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } + ); - if (request.params.name === 'test-tool') { - const result = { - content: [{ type: 'text', text: 'Success!' }] - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); - } - return result; - } - throw new Error('Unknown tool'); - }); + const result = { + content: [{ type: 'text', text: 'Success!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; } - ] - })); + } + ); const client = new Client({ name: 'test-client', @@ -1511,14 +1504,13 @@ describe('Task-based execution', () => { }); test('should query task result from server using getTaskResult', async () => { - const server = new Server( + const server = new McpServer( { name: 'test-server', version: '1.0.0' }, { capabilities: { - tools: {}, tasks: { requests: { tools: { @@ -1532,45 +1524,42 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - let taskId: string | undefined; - - // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; - } + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } + ); - if (request.params.name === 'test-tool') { - const result = { - content: [{ type: 'text', text: 'Result data!' }] - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); - } - return result; - } - throw new Error('Unknown tool'); - }); + const result = { + content: [{ type: 'text', text: 'Result data!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; } - ] - })); + } + ); const client = new Client({ name: 'test-client', @@ -1595,14 +1584,13 @@ describe('Task-based execution', () => { }); test('should query task list from server using listTasks', async () => { - const server = new Server( + const server = new McpServer( { name: 'test-server', version: '1.0.0' }, { capabilities: { - tools: {}, tasks: { requests: { tools: { @@ -1615,44 +1603,42 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - let taskId: string | undefined; + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } + ); - // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; - } - if (request.params.name === 'test-tool') { - const result = { - content: [{ type: 'text', text: 'Success!' }] - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); - } - return result; - } - throw new Error('Unknown tool'); - }); + const result = { + content: [{ type: 'text', text: 'Success!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; } - ] - })); + } + ); const client = new Client({ name: 'test-client', @@ -2119,14 +2105,13 @@ describe('Task-based execution', () => { test('should list tasks from server with pagination', async () => { const serverTaskStore = new InMemoryTaskStore(); - const server = new Server( + const server = new McpServer( { name: 'test-server', version: '1.0.0' }, { capabilities: { - tools: {}, tasks: { requests: { tools: { @@ -2139,44 +2124,44 @@ describe('Task-based execution', () => { } ); - server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - let taskId: string | undefined; - - // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; - } - if (request.params.name === 'test-tool') { - const result = { - content: [{ type: 'text', text: `Result for ${request.params.arguments?.id || 'unknown'}` }] - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: { + id: z.string().optional() } - return result; - } - throw new Error('Unknown tool'); - }); + }, + { + async createTask({ id }, extra) { + const task = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + { method: 'tools/call', params: { name: 'test-tool', arguments: { id } } } + ); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} + const result = { + content: [{ type: 'text', text: `Result for ${id || 'unknown'}` }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; } - ] - })); + } + ); const client = new Client( { @@ -2394,14 +2379,13 @@ describe('Task-based execution', () => { test('should respect server task capabilities', async () => { const serverTaskStore = new InMemoryTaskStore(); - const server = new Server( + const server = new McpServer( { name: 'test-server', version: '1.0.0' }, { capabilities: { - tools: {}, tasks: { requests: { tools: { @@ -2414,22 +2398,42 @@ test('should respect server task capabilities', async () => { } ); - server.setRequestHandler(CallToolRequestSchema, async () => ({ - content: [{ type: 'text', text: 'Success!' }] - })); + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } + ); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} + const result = { + content: [{ type: 'text', text: 'Success!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; } - ] - })); + } + ); const client = new Client( { @@ -2446,7 +2450,9 @@ test('should respect server task capabilities', async () => { // Server supports task creation for tools/call expect(client.getServerCapabilities()).toEqual({ - tools: {}, + tools: { + listChanged: true + }, tasks: { requests: { tools: { diff --git a/src/client/index.ts b/src/client/index.ts index cbbf41a9f..44e1defe4 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,6 +1,6 @@ import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; -import { ResponseMessage } from '../shared/responseMessage.js'; +import { ResponseMessage, takeResult } from '../shared/responseMessage.js'; import { type CallToolRequest, @@ -161,6 +161,7 @@ export class Client< private _instructions?: string; private _jsonSchemaValidator: jsonSchemaValidator; private _cachedToolOutputValidators: Map> = new Map(); + private _cachedKnownTaskTools: Set = new Set(); /** * Initializes this client with the given name and version information. @@ -533,126 +534,12 @@ export class Client< * * For task-based execution with streaming behavior, use callToolStream() instead. */ - async callTool( + async callTool( params: CallToolRequest['params'], - resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, - options?: RequestOptions - ) { - const result = await this.request({ method: 'tools/call', params }, resultSchema, options); - - // Check if the tool has an outputSchema - const validator = this.getToolOutputValidator(params.name); - if (validator) { - // If tool has outputSchema, it MUST return structuredContent (unless it's an error) - if (!result.structuredContent && !result.isError) { - throw new McpError( - ErrorCode.InvalidRequest, - `Tool ${params.name} has an output schema but did not return structured content` - ); - } - - // Only validate structured content if present (not when there's an error) - if (result.structuredContent) { - try { - // Validate the structured content against the schema - const validationResult = validator(result.structuredContent); - - if (!validationResult.valid) { - throw new McpError( - ErrorCode.InvalidParams, - `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` - ); - } - } catch (error) { - if (error instanceof McpError) { - throw error; - } - throw new McpError( - ErrorCode.InvalidParams, - `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` - ); - } - } - } - - return result; - } - - /** - * Cache validators for tool output schemas. - * Called after listTools() to pre-compile validators for better performance. - */ - private cacheToolOutputSchemas(tools: Tool[]): void { - this._cachedToolOutputValidators.clear(); - - for (const tool of tools) { - // If the tool has an outputSchema, create and cache the validator - if (tool.outputSchema) { - const toolValidator = this._jsonSchemaValidator.getValidator(tool.outputSchema as JsonSchemaType); - this._cachedToolOutputValidators.set(tool.name, toolValidator); - } - } - } - - /** - * Get cached validator for a tool - */ - private getToolOutputValidator(toolName: string): JsonSchemaValidator | undefined { - return this._cachedToolOutputValidators.get(toolName); - } - - async listTools(params?: ListToolsRequest['params'], options?: RequestOptions) { - const result = await this.request({ method: 'tools/list', params }, ListToolsResultSchema, options); - - // Cache the tools and their output schemas for future validation - this.cacheToolOutputSchemas(result.tools); - - return result; - } - - async sendRootsListChanged() { - return this.notification({ method: 'notifications/roots/list_changed' }); - } - - /** - * Sends a request and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a 'result' or 'error' message. - * - * This method provides streaming access to request processing, allowing you to - * observe intermediate task status updates for task-augmented requests. - * - * @example - * ```typescript - * const stream = client.requestStream(request, resultSchema, options); - * for await (const message of stream) { - * switch (message.type) { - * case 'taskCreated': - * console.log('Task created:', message.task.taskId); - * break; - * case 'taskStatus': - * console.log('Task status:', message.task.status); - * break; - * case 'result': - * console.log('Final result:', message.result); - * break; - * case 'error': - * console.error('Error:', message.error); - * break; - * } - * } - * ``` - * - * @param request - The request to send - * @param resultSchema - Zod schema for validating the result - * @param options - Optional request options (timeout, signal, task creation params, etc.) - * @returns AsyncGenerator that yields ResponseMessage objects - */ - requestStream>( - request: ClientRequest | RequestT, - resultSchema: T, + resultSchema: T = CallToolResultSchema as T, options?: RequestOptions - ): AsyncGenerator>, void, void> { - return super.requestStream(request, resultSchema, options); + ): Promise> { + return await takeResult(this.callToolStream(params, resultSchema, options)); } /** @@ -700,8 +587,9 @@ export class Client< // Add task creation parameters if server supports it and not explicitly provided const optionsWithTask = { ...options, - // We check the server capabilities in auto-assignment, but assume the caller knows what they're doing if they pass this explicitly - task: options?.task ?? (this._serverCapabilities?.tasks?.requests?.tools?.call ? {} : undefined) + // We check if the tool is known to be a task during auto-configuration, but assume + // the caller knows what they're doing if they pass this explicitly + task: options?.task ?? (this.isToolTask(params.name) ? {} : undefined) }; const stream = this.requestStream({ method: 'tools/call', params }, resultSchema, optionsWithTask); @@ -764,4 +652,96 @@ export class Client< yield message; } } + + private isToolTask(toolName: string): boolean { + if (!this._serverCapabilities?.tasks?.requests?.tools?.call) { + return false; + } + + return this._cachedKnownTaskTools.has(toolName); + } + + /** + * Cache validators for tool output schemas. + * Called after listTools() to pre-compile validators for better performance. + */ + private cacheToolMetadata(tools: Tool[]): void { + this._cachedToolOutputValidators.clear(); + this._cachedKnownTaskTools.clear(); + + for (const tool of tools) { + // If the tool has an outputSchema, create and cache the validator + if (tool.outputSchema) { + const toolValidator = this._jsonSchemaValidator.getValidator(tool.outputSchema as JsonSchemaType); + this._cachedToolOutputValidators.set(tool.name, toolValidator); + } + + // If the tool supports task-based execution, cache that information + const taskHint = tool.annotations?.taskHint; + if (taskHint === 'always' || taskHint === 'optional') { + this._cachedKnownTaskTools.add(tool.name); + } + } + } + + /** + * Get cached validator for a tool + */ + private getToolOutputValidator(toolName: string): JsonSchemaValidator | undefined { + return this._cachedToolOutputValidators.get(toolName); + } + + async listTools(params?: ListToolsRequest['params'], options?: RequestOptions) { + const result = await this.request({ method: 'tools/list', params }, ListToolsResultSchema, options); + + // Cache the tools and their output schemas for future validation + this.cacheToolMetadata(result.tools); + + return result; + } + + async sendRootsListChanged() { + return this.notification({ method: 'notifications/roots/list_changed' }); + } + + /** + * Sends a request and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to request processing, allowing you to + * observe intermediate task status updates for task-augmented requests. + * + * @example + * ```typescript + * const stream = client.requestStream(request, resultSchema, options); + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Task created:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Task status:', message.task.status); + * break; + * case 'result': + * console.log('Final result:', message.result); + * break; + * case 'error': + * console.error('Error:', message.error); + * break; + * } + * } + * ``` + * + * @param request - The request to send + * @param resultSchema - Zod schema for validating the result + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + */ + requestStream>( + request: ClientRequest | RequestT, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + return super.requestStream(request, resultSchema, options); + } } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index a48f2d639..107295ba9 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -21,6 +21,7 @@ import { SUPPORTED_PROTOCOL_VERSIONS } from '../types.js'; import { Server } from './index.js'; +import { McpServer } from './mcp.js'; import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; import { CallToolRequestSchema, CallToolResultSchema } from '../types.js'; @@ -958,14 +959,13 @@ describe('Task-based execution', () => { test('server with TaskStore should handle task-based tool execution', async () => { const taskStore = new InMemoryTaskStore(); - const server = new Server( + const server = new McpServer( { name: 'test-server', version: '1.0.0' }, { capabilities: { - tools: {}, tasks: { requests: { tools: { @@ -978,48 +978,47 @@ describe('Task-based execution', () => { } ); - // Set up a tool handler that simulates some work - server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - let taskId: string | undefined; - - // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; - } + // Register a tool using registerToolTask + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } + ); - if (request.params.name === 'test-tool') { - // Simulate some async work - await new Promise(resolve => setTimeout(resolve, 10)); - const result = { - content: [{ type: 'text', text: 'Tool executed successfully!' }] - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); - } - return result; - } - throw new Error('Unknown tool'); - }); + // Simulate some async work + (async () => { + await new Promise(resolve => setTimeout(resolve, 10)); + const result = { + content: [{ type: 'text', text: 'Tool executed successfully!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + })(); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; } - ] - })); + } + ); const client = new Client( { @@ -1043,7 +1042,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Use beginCallTool to create a task + // Use callTool to create a task await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 @@ -1051,6 +1050,7 @@ describe('Task-based execution', () => { }); // Wait for the task to complete + await new Promise(resolve => setTimeout(resolve, 50)); // Get the task ID from the task list since it's generated automatically const taskList = await client.listTasks(); @@ -1136,14 +1136,13 @@ describe('Task-based execution', () => { test('should automatically attach related-task metadata to nested requests during tool execution', async () => { const taskStore = new InMemoryTaskStore(); - const server = new Server( + const server = new McpServer( { name: 'test-server', version: '1.0.0' }, { capabilities: { - tools: {}, tasks: { requests: { tools: { @@ -1205,69 +1204,69 @@ describe('Task-based execution', () => { }; }); - // Set up server tool that makes a nested elicitation request - server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - let taskId: string | undefined; - - // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; - } - - if (request.params.name === 'collect-info') { - // During tool execution, make a nested request to the client using extra.sendRequest - const elicitResult = await extra.sendRequest( - { - method: 'elicitation/create', - params: { - message: 'Please provide your username', - requestedSchema: { - type: 'object', - properties: { - username: { type: 'string' } - }, - required: ['username'] - } - } - }, - ElicitResultSchema - ); - - const result = { - content: [ + // Register a tool using registerToolTask that makes a nested elicitation request + server.registerToolTask( + 'collect-info', + { + description: 'Collects user info via elicitation', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask( { - type: 'text', - text: `Collected username: ${elicitResult.action === 'accept' && elicitResult.content ? (elicitResult.content as Record).username : 'none'}` - } - ] - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); - } - return result; - } - throw new Error('Unknown tool'); - }); + ttl: extra.taskRequestedTtl + }, + extra.requestId, + { method: 'tools/call', params: { name: 'collect-info', arguments: {} } } + ); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'collect-info', - description: 'Collects user info via elicitation', - inputSchema: { - type: 'object', - properties: {} + // Perform async work that makes a nested request + (async () => { + // During tool execution, make a nested request to the client using extra.sendRequest + const elicitResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' } + }, + required: ['username'] + } + } + }, + ElicitResultSchema + ); + + const result = { + content: [ + { + type: 'text', + text: `Collected username: ${elicitResult.action === 'accept' && elicitResult.content ? (elicitResult.content as Record).username : 'none'}` + } + ] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; } - ] - })); + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1281,6 +1280,7 @@ describe('Task-based execution', () => { }); // Wait for completion + await new Promise(resolve => setTimeout(resolve, 50)); // Verify the nested elicitation request was made (related-task metadata is no longer automatically attached) expect(capturedElicitRequest).toBeDefined(); @@ -1694,14 +1694,13 @@ describe('Task-based execution', () => { test('should handle multiple concurrent task-based tool calls', async () => { const taskStore = new InMemoryTaskStore(); - const server = new Server( + const server = new McpServer( { name: 'test-server', version: '1.0.0' }, { capabilities: { - tools: {}, tasks: { requests: { tools: { @@ -1714,50 +1713,50 @@ describe('Task-based execution', () => { } ); - // Set up a tool handler with variable delay - server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - let taskId: string | undefined; - - // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); - taskId = createdTask.taskId; - } - if (request.params.name === 'async-tool') { - const delay = (request.params.arguments?.delay as number) || 10; - await new Promise(resolve => setTimeout(resolve, delay)); - const result = { - content: [{ type: 'text', text: `Completed task ${request.params.arguments?.taskNum || 'unknown'}` }] - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + // Register a tool using registerToolTask with variable delay + server.registerToolTask( + 'async-tool', + { + description: 'An async test tool', + inputSchema: { + delay: z.number().optional().default(10), + taskNum: z.number().optional() } - return result; - } - throw new Error('Unknown tool'); - }); + }, + { + async createTask({ delay, taskNum }, extra) { + const task = await extra.taskStore.createTask( + { + ttl: extra.taskRequestedTtl + }, + extra.requestId, + { method: 'tools/call', params: { name: 'async-tool', arguments: { delay, taskNum } } } + ); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'async-tool', - description: 'An async test tool', - inputSchema: { - type: 'object', - properties: { - delay: { type: 'number' }, - taskNum: { type: 'number' } - } + // Simulate async work + (async () => { + await new Promise(resolve => setTimeout(resolve, delay)); + const result = { + content: [{ type: 'text', text: `Completed task ${taskNum || 'unknown'}` }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; } - ] - })); + } + ); const client = new Client( { @@ -1791,6 +1790,9 @@ describe('Task-based execution', () => { // Wait for all tasks to complete await Promise.all(pendingRequests); + // Wait a bit more to ensure all tasks are completed + await new Promise(resolve => setTimeout(resolve, 50)); + // Get all task IDs from the task list const taskList = await client.listTasks(); expect(taskList.tasks.length).toBeGreaterThanOrEqual(4); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index ebbd32def..5dffe6cc8 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -19,7 +19,7 @@ import { Transport, TransportSendOptions } from './transport.js'; import { TaskStore } from './task.js'; import { MockInstance, vi } from 'vitest'; import { JSONRPCResponse, JSONRPCRequest, JSONRPCError } from '../types.js'; -import { ErrorMessage, ResponseMessage } from './responseMessage.js'; +import { ErrorMessage, ResponseMessage, toArrayAsync } from './responseMessage.js'; // Type helper for accessing private Protocol properties in tests interface TestProtocol { @@ -118,17 +118,6 @@ function createLatch() { }; } -type AsyncGeneratorValue = T extends AsyncGenerator ? U : never; - -async function toArrayAsync>(it: T): Promise[]> { - const arr: AsyncGeneratorValue[] = []; - for await (const o of it) { - arr.push(o as AsyncGeneratorValue); - } - - return arr; -} - function assertErrorResponse(o: ResponseMessage): asserts o is ErrorMessage { expect(o.type).toBe('error'); } diff --git a/src/shared/responseMessage.ts b/src/shared/responseMessage.ts index c7584c0cd..6fefcf1f6 100644 --- a/src/shared/responseMessage.ts +++ b/src/shared/responseMessage.ts @@ -45,3 +45,26 @@ export interface ErrorMessage extends BaseResponseMessage { * Side-channeled messages (server requests/notifications) are handled through registered handlers. */ export type ResponseMessage = TaskStatusMessage | TaskCreatedMessage | ResultMessage | ErrorMessage; + +export type AsyncGeneratorValue = T extends AsyncGenerator ? U : never; + +export async function toArrayAsync>(it: T): Promise[]> { + const arr: AsyncGeneratorValue[] = []; + for await (const o of it) { + arr.push(o as AsyncGeneratorValue); + } + + return arr; +} + +export async function takeResult>>(it: U): Promise { + for await (const o of it) { + if (o.type === 'result') { + return o.result; + } else if (o.type === 'error') { + throw o.error; + } + } + + throw new Error('No result in stream.'); +} From e2f6e899d444337ec71c50ce31f3a45d758a5820 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 20 Nov 2025 16:28:40 -0800 Subject: [PATCH 56/84] Add missing schema --- src/client/index.test.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 8703b78a2..b3f13b2bc 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -2508,7 +2508,9 @@ describe('Task-based execution', () => { 'test-tool', { description: 'A test tool', - inputSchema: {} + inputSchema: { + id: z4.string() + } }, { async createTask({ id }, extra) { From b9c3054a83f0eda83d0bf54b684726fee049954e Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 20 Nov 2025 16:30:53 -0800 Subject: [PATCH 57/84] Remove now-unused isomorphic UUID dependency --- package-lock.json | 22 ---------------------- package.json | 1 - 2 files changed, 23 deletions(-) diff --git a/package-lock.json b/package-lock.json index dc109647e..d6791ef06 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,7 +9,6 @@ "version": "1.23.0-beta.0", "license": "MIT", "dependencies": { - "@lukeed/uuid": "^2.0.1", "ajv": "^8.17.1", "ajv-formats": "^3.0.1", "content-type": "^1.0.5", @@ -692,27 +691,6 @@ "dev": true, "license": "MIT" }, - "node_modules/@lukeed/csprng": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@lukeed/csprng/-/csprng-1.1.0.tgz", - "integrity": "sha512-Z7C/xXCiGWsg0KuKsHTKJxbWhpI3Vs5GwLfOean7MGyVFGqdRgBbAjOCh6u4bbjPc/8MJ2pZmK/0DLdCbivLDA==", - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/@lukeed/uuid": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@lukeed/uuid/-/uuid-2.0.1.tgz", - "integrity": "sha512-qC72D4+CDdjGqJvkFMMEAtancHUQ7/d/tAiHf64z8MopFDmcrtbcJuerDtFceuAfQJ2pDSfCKCtbqoGBNnwg0w==", - "license": "MIT", - "dependencies": { - "@lukeed/csprng": "^1.1.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/@noble/hashes": { "version": "1.8.0", "resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.8.0.tgz", diff --git a/package.json b/package.json index ba48d09d0..560c4bcb2 100644 --- a/package.json +++ b/package.json @@ -78,7 +78,6 @@ "client": "tsx scripts/cli.ts client" }, "dependencies": { - "@lukeed/uuid": "^2.0.1", "ajv": "^8.17.1", "ajv-formats": "^3.0.1", "content-type": "^1.0.5", From cd89e769c779768a24d6db9898f8520e033e73e7 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 20 Nov 2025 16:38:42 -0800 Subject: [PATCH 58/84] Use same polling interval in waitForTaskUpdate as caller would use --- src/shared/protocol.ts | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 198faea8e..8e5c216e1 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1513,6 +1513,17 @@ export abstract class Protocol { + // Get the task's poll interval, falling back to default + let interval = this._options?.defaultTaskPollInterval ?? 1000; + try { + const task = await this._taskStore?.getTask(taskId); + if (task?.pollInterval) { + interval = task.pollInterval; + } + } catch { + // Use default interval if task lookup fails + } + return new Promise((resolve, reject) => { if (signal.aborted) { reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); @@ -1543,7 +1554,7 @@ export abstract class Protocol clearInterval(pollInterval); @@ -1598,7 +1609,6 @@ export abstract class Protocol Date: Thu, 20 Nov 2025 19:56:52 -0800 Subject: [PATCH 59/84] Make queue implementation swappable --- src/client/index.test.ts | 2 +- src/examples/server/simpleStreamableHttp.ts | 5 +- src/examples/shared/inMemoryTaskStore.ts | 81 +++- src/integration-tests/taskLifecycle.test.ts | 5 +- src/server/index.test.ts | 2 +- src/shared/protocol.test.ts | 513 +++++++++----------- src/shared/protocol.ts | 203 +++----- src/shared/task-listing.test.ts | 5 +- src/shared/task.ts | 66 ++- 9 files changed, 442 insertions(+), 440 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index b3f13b2bc..7ab683b72 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -25,7 +25,7 @@ import { Transport } from '../shared/transport.js'; import { Server } from '../server/index.js'; import { McpServer } from '../server/mcp.js'; import { InMemoryTransport } from '../inMemory.js'; -import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 1948abb82..796e9d13f 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -15,7 +15,7 @@ import { ResourceLink } from '../../types.js'; import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; -import { InMemoryTaskStore } from '../shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../shared/inMemoryTaskStore.js'; import { setupAuthServer } from './demoInMemoryOAuthProvider.js'; import { OAuthMetadata } from '../../shared/auth.js'; import { checkResourceAllowed } from '../../shared/auth-utils.js'; @@ -40,7 +40,8 @@ const getServer = () => { }, { capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } }, - taskStore // Enable task support + taskStore, // Enable task support + taskMessageQueue: new InMemoryTaskMessageQueue() } ); diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index 3414d2bbd..ad0ab4de6 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -1,5 +1,5 @@ import { Task, TaskCreationParams, Request, RequestId, Result } from '../../types.js'; -import { TaskStore, isTerminal } from '../../shared/task.js'; +import { TaskStore, isTerminal, TaskMessageQueue, QueuedMessage } from '../../shared/task.js'; import { randomBytes } from 'crypto'; interface StoredTask { @@ -198,3 +198,82 @@ export class InMemoryTaskStore implements TaskStore { return Array.from(this.tasks.values()).map(stored => ({ ...stored.task })); } } + +/** + * A simple in-memory implementation of TaskMessageQueue for demonstration purposes. + * + * This implementation stores messages in memory, organized by task ID and optional session ID. + * Messages are stored in FIFO queues per task. + * + * Note: This is not suitable for production use in distributed systems. + * For production, consider implementing TaskMessageQueue with Redis or other distributed queues. + */ +export class InMemoryTaskMessageQueue implements TaskMessageQueue { + private queues = new Map(); + + /** + * Generates a queue key from taskId. + * SessionId is intentionally ignored because taskIds are globally unique + * and tasks need to be accessible across HTTP requests/sessions. + */ + private getQueueKey(taskId: string, _sessionId?: string): string { + return taskId; + } + + /** + * Gets or creates a queue for the given task and session. + */ + private getQueue(taskId: string, sessionId?: string): QueuedMessage[] { + const key = this.getQueueKey(taskId, sessionId); + let queue = this.queues.get(key); + if (!queue) { + queue = []; + this.queues.set(key, queue); + } + return queue; + } + + /** + * Adds a message to the end of the queue for a specific task. + * Atomically checks queue size and throws if maxSize would be exceeded. + * @param taskId The task identifier + * @param message The message to enqueue + * @param sessionId Optional session ID for binding the operation to a specific session + * @param maxSize Optional maximum queue size - if specified and queue is full, throws an error + * @throws Error if maxSize is specified and would be exceeded + */ + async enqueue(taskId: string, message: QueuedMessage, sessionId?: string, maxSize?: number): Promise { + const queue = this.getQueue(taskId, sessionId); + + // Atomically check size and enqueue + if (maxSize !== undefined && queue.length >= maxSize) { + throw new Error(`Task message queue overflow: queue size (${queue.length}) exceeds maximum (${maxSize})`); + } + + queue.push(message); + } + + /** + * Removes and returns the first message from the queue for a specific task. + * @param taskId The task identifier + * @param sessionId Optional session ID for binding the query to a specific session + * @returns The first message, or undefined if the queue is empty + */ + async dequeue(taskId: string, sessionId?: string): Promise { + const queue = this.getQueue(taskId, sessionId); + return queue.shift(); + } + + /** + * Removes and returns all messages from the queue for a specific task. + * @param taskId The task identifier + * @param sessionId Optional session ID for binding the query to a specific session + * @returns Array of all messages that were in the queue + */ + async dequeueAll(taskId: string, sessionId?: string): Promise { + const key = this.getQueueKey(taskId, sessionId); + const queue = this.queues.get(key) ?? []; + this.queues.delete(key); + return queue; + } +} diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts index 0c2ca1c5a..27b64b235 100644 --- a/src/integration-tests/taskLifecycle.test.ts +++ b/src/integration-tests/taskLifecycle.test.ts @@ -7,7 +7,7 @@ import { McpServer } from '../server/mcp.js'; import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; import { CallToolResultSchema, CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, TaskSchema } from '../types.js'; import { z } from 'zod'; -import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; import type { TaskRequestOptions } from '../shared/protocol.js'; describe('Task Lifecycle Integration Tests', () => { @@ -36,7 +36,8 @@ describe('Task Lifecycle Integration Tests', () => { cancel: {} } }, - taskStore + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() } ); diff --git a/src/server/index.test.ts b/src/server/index.test.ts index cf754c9c9..516a5369b 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -23,7 +23,7 @@ import { } from '../types.js'; import { Server } from './index.js'; import { McpServer } from './mcp.js'; -import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; import { CallToolRequestSchema, CallToolResultSchema } from '../types.js'; import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js'; import type { AnyObjectSchema } from './zod-compat.js'; diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 5dffe6cc8..1f58eb861 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -14,21 +14,22 @@ import { Task, TaskCreationParams } from '../types.js'; -import { Protocol, mergeCapabilities, TaskMessageQueue } from './protocol.js'; +import { Protocol, mergeCapabilities } from './protocol.js'; import { Transport, TransportSendOptions } from './transport.js'; -import { TaskStore } from './task.js'; +import { TaskStore, TaskMessageQueue } from './task.js'; import { MockInstance, vi } from 'vitest'; import { JSONRPCResponse, JSONRPCRequest, JSONRPCError } from '../types.js'; import { ErrorMessage, ResponseMessage, toArrayAsync } from './responseMessage.js'; +import { InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; // Type helper for accessing private Protocol properties in tests interface TestProtocol { - _taskMessageQueues: Map; + _taskMessageQueue?: TaskMessageQueue; _taskResultWaiters: Map void>>; _requestResolvers: Map void>; _responseHandlers: Map void>; _taskProgressTokens: Map; - _clearTaskQueue: (taskId: string) => void; + _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; requestTaskStore: (request: Request, authInfo: unknown) => TaskStore; } @@ -795,15 +796,16 @@ describe('protocol tests', () => { }); }); -describe('TaskMessageQueue', () => { +describe('InMemoryTaskMessageQueue', () => { let queue: TaskMessageQueue; + const taskId = 'test-task-id'; beforeEach(() => { - queue = new TaskMessageQueue(); + queue = new InMemoryTaskMessageQueue(); }); describe('enqueue/dequeue maintains FIFO order', () => { - it('should maintain FIFO order for multiple messages', () => { + it('should maintain FIFO order for multiple messages', async () => { const msg1 = { type: 'notification' as const, message: { jsonrpc: '2.0' as const, method: 'test1' }, @@ -820,106 +822,22 @@ describe('TaskMessageQueue', () => { timestamp: 3 }; - queue.enqueue(msg1); - queue.enqueue(msg2); - queue.enqueue(msg3); + await queue.enqueue(taskId, msg1); + await queue.enqueue(taskId, msg2); + await queue.enqueue(taskId, msg3); - expect(queue!.dequeue()).toEqual(msg1); - expect(queue!.dequeue()).toEqual(msg2); - expect(queue!.dequeue()).toEqual(msg3); + expect(await queue.dequeue(taskId)).toEqual(msg1); + expect(await queue.dequeue(taskId)).toEqual(msg2); + expect(await queue.dequeue(taskId)).toEqual(msg3); }); - it('should return undefined when dequeuing from empty queue', () => { - expect(queue!.dequeue()).toBeUndefined(); - }); - }); - - describe('clear operation', () => { - it('should remove all messages from queue', () => { - queue.enqueue({ - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test1' }, - timestamp: 1 - }); - queue.enqueue({ - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test2' }, - timestamp: 2 - }); - queue.enqueue({ - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test3' }, - timestamp: 3 - }); - - expect(queue!.size()).toBe(3); - - queue.clear(); - - expect(queue!.size()).toBe(0); - expect(queue.isEmpty()).toBe(true); - expect(queue!.dequeue()).toBeUndefined(); - }); - - it('should work on empty queue', () => { - expect(() => queue.clear()).not.toThrow(); - expect(queue.isEmpty()).toBe(true); - }); - }); - - describe('isEmpty and size methods', () => { - it('should return true for empty queue', () => { - expect(queue.isEmpty()).toBe(true); - expect(queue!.size()).toBe(0); - }); - - it('should return false after enqueuing', () => { - queue.enqueue({ - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test' }, - timestamp: 1 - }); - expect(queue.isEmpty()).toBe(false); - expect(queue!.size()).toBe(1); - }); - - it('should return correct size for multiple messages', () => { - for (let i = 0; i < 5; i++) { - queue.enqueue({ - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: `test${i}` }, - timestamp: i - }); - } - expect(queue!.size()).toBe(5); - expect(queue.isEmpty()).toBe(false); - }); - - it('should update size correctly after dequeue', () => { - queue.enqueue({ - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test1' }, - timestamp: 1 - }); - queue.enqueue({ - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test2' }, - timestamp: 2 - }); - expect(queue!.size()).toBe(2); - - queue!.dequeue(); - expect(queue!.size()).toBe(1); - expect(queue.isEmpty()).toBe(false); - - queue!.dequeue(); - expect(queue!.size()).toBe(0); - expect(queue.isEmpty()).toBe(true); + it('should return undefined when dequeuing from empty queue', async () => { + expect(await queue.dequeue(taskId)).toBeUndefined(); }); }); describe('dequeueAll operation', () => { - it('should return all messages in FIFO order', () => { + it('should return all messages in FIFO order', async () => { const msg1 = { type: 'notification' as const, message: { jsonrpc: '2.0' as const, method: 'test1' }, @@ -936,39 +854,35 @@ describe('TaskMessageQueue', () => { timestamp: 3 }; - queue.enqueue(msg1); - queue.enqueue(msg2); - queue.enqueue(msg3); + await queue.enqueue(taskId, msg1); + await queue.enqueue(taskId, msg2); + await queue.enqueue(taskId, msg3); - const allMessages = queue.dequeueAll(); + const allMessages = await queue.dequeueAll(taskId); expect(allMessages).toEqual([msg1, msg2, msg3]); - expect(queue.isEmpty()).toBe(true); - expect(queue!.size()).toBe(0); }); - it('should return empty array for empty queue', () => { - const allMessages = queue.dequeueAll(); + it('should return empty array for empty queue', async () => { + const allMessages = await queue.dequeueAll(taskId); expect(allMessages).toEqual([]); - expect(queue.isEmpty()).toBe(true); }); - it('should clear queue after dequeueAll', () => { - queue.enqueue({ + it('should clear queue after dequeueAll', async () => { + await queue.enqueue(taskId, { type: 'notification' as const, message: { jsonrpc: '2.0' as const, method: 'test1' }, timestamp: 1 }); - queue.enqueue({ + await queue.enqueue(taskId, { type: 'notification' as const, message: { jsonrpc: '2.0' as const, method: 'test2' }, timestamp: 2 }); - queue.dequeueAll(); + await queue.dequeueAll(taskId); - expect(queue!.dequeue()).toBeUndefined(); - expect(queue!.size()).toBe(0); + expect(await queue.dequeue(taskId)).toBeUndefined(); }); }); }); @@ -1077,7 +991,7 @@ describe('Task-based execution', () => { protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} protected assertTaskHandlerCapability(): void {} - })(); + })({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }); }); describe('request with task metadata', () => { @@ -1215,9 +1129,8 @@ describe('Task-based execution', () => { expect(sendSpy).not.toHaveBeenCalled(); // Verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueues.get('parent-task-123'); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; expect(queue).toBeDefined(); - expect(queue!.size()).toBeGreaterThan(0); }); it('should work with notification method', async () => { @@ -1240,11 +1153,10 @@ describe('Task-based execution', () => { expect(sendSpy).not.toHaveBeenCalled(); // Verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueues.get('parent-task-456'); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; expect(queue).toBeDefined(); - expect(queue!.size()).toBe(1); - const queuedMessage = queue!.dequeue(); + const queuedMessage = await queue!.dequeue('parent-task-456'); expect(queuedMessage).toBeDefined(); expect(queuedMessage?.type).toBe('notification'); expect(queuedMessage?.message.method).toBe('notifications/message'); @@ -1289,11 +1201,10 @@ describe('Task-based execution', () => { expect(sendSpy).not.toHaveBeenCalled(); // Verify the message was queued with all metadata combined - const queue = (protocol as unknown as TestProtocol)._taskMessageQueues.get('parent-task'); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; expect(queue).toBeDefined(); - expect(queue!.size()).toBeGreaterThan(0); - const queuedMessage = queue!.dequeue(); + const queuedMessage = await queue!.dequeue('parent-task'); expect(queuedMessage).toBeDefined(); expect(queuedMessage?.type).toBe('request'); expect(queuedMessage?.message.params).toMatchObject({ @@ -2024,7 +1935,7 @@ describe('Task-based execution', () => { protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -2076,13 +1987,10 @@ describe('Task-based execution', () => { // Verify the notification was QUEUED (not sent via transport) // Messages with relatedTask metadata should be queued for delivery via tasks/result // to prevent duplicate delivery for bidirectional transports - const queues = (serverProtocol as unknown as TestProtocol)._taskMessageQueues; - expect(queues.has('parent-task-123')).toBe(true); - - const queue = queues.get('parent-task-123')!; - expect(queue.size()).toBeGreaterThan(0); + const queue = (serverProtocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); - const queuedMessage = queue.dequeue(); + const queuedMessage = await queue!.dequeue('parent-task-123'); expect(queuedMessage).toBeDefined(); expect(queuedMessage?.type).toBe('notification'); expect(queuedMessage?.message.method).toBe('notifications/message'); @@ -3107,7 +3015,7 @@ describe('Message interception for task-related notifications', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3126,11 +3034,10 @@ describe('Message interception for task-related notifications', () => { ); // Access the private queue to verify the message was queued - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + const queue = (server as unknown as TestProtocol)._taskMessageQueue; expect(queue).toBeDefined(); - expect(queue!.size()).toBe(1); - const queuedMessage = queue!.dequeue(); + const queuedMessage = await queue!.dequeue(task.taskId); expect(queuedMessage).toBeDefined(); expect(queuedMessage?.type).toBe('notification'); expect(queuedMessage?.message.method).toBe('notifications/message'); @@ -3146,7 +3053,7 @@ describe('Message interception for task-related notifications', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3156,9 +3063,9 @@ describe('Message interception for task-related notifications', () => { params: { level: 'info', data: 'test message' } }); - // Verify no queues were created - const queues = (server as unknown as TestProtocol)._taskMessageQueues; - expect(queues.size).toBe(0); + // Verify message was not queued (notification without metadata goes through transport) + // We can't directly check the queue, but we know it wasn't queued because + // notifications without relatedTask metadata are sent via transport, not queued }); it('should notify task result waiters after queuing', async () => { @@ -3170,7 +3077,7 @@ describe('Message interception for task-related notifications', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3211,7 +3118,7 @@ describe('Message interception for task-related notifications', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, maxTaskQueueSize: 100 }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); await server.connect(transport); @@ -3231,10 +3138,6 @@ describe('Message interception for task-related notifications', () => { ); } - // Verify queue is at max capacity - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); - expect(queue!.size()).toBe(100); - // Try to add one more message - should throw and fail the task await expect( server.notification( @@ -3248,11 +3151,8 @@ describe('Message interception for task-related notifications', () => { ) ).rejects.toThrow(McpError); - // Verify the task was failed - expect(taskStore.updateTaskStatus).toHaveBeenCalledWith(task.taskId, 'failed', 'Task message queue overflow'); - - // Verify the queue was cleared - expect(queue!.size()).toBe(0); + // Verify the task was failed with overflow error + expect(taskStore.updateTaskStatus).toHaveBeenCalledWith(task.taskId, 'failed', expect.stringContaining('overflow'), undefined); }); it('should extract task ID correctly from metadata', async () => { @@ -3264,7 +3164,7 @@ describe('Message interception for task-related notifications', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3282,9 +3182,10 @@ describe('Message interception for task-related notifications', () => { ); // Verify the message was queued under the correct task ID - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(taskId); + const queue = (server as unknown as TestProtocol)._taskMessageQueue; expect(queue).toBeDefined(); - expect(queue!.size()).toBe(1); + const queuedMessage = await queue!.dequeue(taskId); + expect(queuedMessage).toBeDefined(); }); it('should preserve message order when queuing multiple notifications', async () => { @@ -3296,7 +3197,7 @@ describe('Message interception for task-related notifications', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3317,11 +3218,12 @@ describe('Message interception for task-related notifications', () => { } // Verify messages are in FIFO order - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); - expect(queue!.size()).toBe(5); + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); for (let i = 0; i < 5; i++) { - const message = queue!.dequeue(); + const message = await queue!.dequeue(task.taskId); + expect(message).toBeDefined(); expect(message!.message.params!.data).toBe(`message ${i}`); } }); @@ -3337,7 +3239,7 @@ describe('Message interception for task-related requests', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3357,11 +3259,10 @@ describe('Message interception for task-related requests', () => { ); // Access the private queue to verify the message was queued - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); + const queue = (server as unknown as TestProtocol)._taskMessageQueue; expect(queue).toBeDefined(); - expect(queue!.size()).toBe(1); - const queuedMessage = queue!.dequeue(); + const queuedMessage = await queue!.dequeue(task.taskId); expect(queuedMessage).toBeDefined(); expect(queuedMessage?.type).toBe('request'); expect(queuedMessage?.message.method).toBe('ping'); @@ -3388,7 +3289,7 @@ describe('Message interception for task-related requests', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3401,9 +3302,9 @@ describe('Message interception for task-related requests', () => { z.object({}) ); - // Verify no queues were created - const queues = (server as unknown as TestProtocol)._taskMessageQueues; - expect(queues.size).toBe(0); + // Verify queue exists (but we don't track size in the new API) + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); // Clean up - send a response transport.onmessage?.({ @@ -3424,7 +3325,7 @@ describe('Message interception for task-related requests', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3452,13 +3353,16 @@ describe('Message interception for task-related requests', () => { } ); + // Wait for the request to be queued and waiter to be called + await new Promise(resolve => setTimeout(resolve, 10)); + // Verify the waiter was called expect(waiterCalled).toBe(true); expect(waiters.has(task.taskId)).toBe(false); // Waiters should be cleared // Clean up - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); - const queuedMessage = queue!.dequeue(); + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue(task.taskId); transport.onmessage?.({ jsonrpc: '2.0', id: queuedMessage!.originalRequestId!, @@ -3477,7 +3381,7 @@ describe('Message interception for task-related requests', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3501,8 +3405,8 @@ describe('Message interception for task-related requests', () => { expect(resolvers.size).toBe(1); // Get the request ID from the queue - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); - const queuedMessage = queue!.dequeue(); + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue(task.taskId); const requestId = queuedMessage!.originalRequestId!; expect(resolvers.has(requestId)).toBe(true); @@ -3529,7 +3433,7 @@ describe('Message interception for task-related requests', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3549,8 +3453,8 @@ describe('Message interception for task-related requests', () => { ); // Get the request ID from the queue - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); - const queuedMessage = queue!.dequeue(); + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue(task.taskId); const requestId = queuedMessage!.originalRequestId!; // Send a response @@ -3574,7 +3478,7 @@ describe('Message interception for task-related requests', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); const errors: Error[] = []; server.onerror = (error: Error) => { @@ -3599,8 +3503,8 @@ describe('Message interception for task-related requests', () => { ); // Get the request ID from the queue - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); - const queuedMessage = queue!.dequeue(); + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue(task.taskId); const requestId = queuedMessage!.originalRequestId!; // Manually delete the response handler to simulate missing resolver @@ -3630,7 +3534,7 @@ describe('Message interception for task-related requests', () => { protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, maxTaskQueueSize: 100 }); + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); await server.connect(transport); @@ -3657,10 +3561,6 @@ describe('Message interception for task-related requests', () => { promises.push(promise); } - // Verify queue is at max capacity - const queue = (server as unknown as TestProtocol)._taskMessageQueues.get(task.taskId); - expect(queue!.size()).toBe(100); - // Try to add one more request - should throw and fail the task await expect( server.request( @@ -3675,11 +3575,8 @@ describe('Message interception for task-related requests', () => { ) ).rejects.toThrow(McpError); - // Verify the task was failed - expect(taskStore.updateTaskStatus).toHaveBeenCalledWith(task.taskId, 'failed', 'Task message queue overflow'); - - // Verify the queue was cleared - expect(queue!.size()).toBe(0); + // Verify the task was failed with overflow error + expect(taskStore.updateTaskStatus).toHaveBeenCalledWith(task.taskId, 'failed', expect.stringContaining('overflow'), undefined); }); }); @@ -3697,7 +3594,7 @@ describe('Message Interception', () => { protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); }); describe('messages with relatedTask metadata are queued', () => { @@ -3717,14 +3614,11 @@ describe('Message Interception', () => { } ); - // Access the private _taskMessageQueues to verify the message was queued - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has('task-123')).toBe(true); - - const queue = queues.get('task-123')!; - expect(queue!.size()).toBe(1); + // Access the private _taskMessageQueue to verify the message was queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); - const queuedMessage = queue!.dequeue(); + const queuedMessage = await queue!.dequeue('task-123'); expect(queuedMessage).toBeDefined(); expect(queuedMessage!.type).toBe('notification'); expect(queuedMessage!.message.method).toBe('notifications/message'); @@ -3749,14 +3643,11 @@ describe('Message Interception', () => { } ); - // Access the private _taskMessageQueues to verify the message was queued - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has('task-456')).toBe(true); - - const queue = queues.get('task-456')!; - expect(queue!.size()).toBe(1); + // Access the private _taskMessageQueue to verify the message was queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); - const queuedMessage = queue!.dequeue(); + const queuedMessage = await queue!.dequeue('task-456'); expect(queuedMessage).toBeDefined(); expect(queuedMessage!.type).toBe('request'); expect(queuedMessage!.message.method).toBe('test/request'); @@ -3782,9 +3673,11 @@ describe('Message Interception', () => { params: { level: 'info', data: 'test message' } }); - // Access the private _taskMessageQueues to verify no queue was created - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.size).toBe(0); + // Access the private _taskMessageQueue to verify no messages were queued + // Since we can't check if queues exist without messages, we verify that + // attempting to dequeue returns undefined (no messages queued) + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); }); it('should not queue requests without relatedTask metadata', async () => { @@ -3802,9 +3695,11 @@ describe('Message Interception', () => { mockSchema ); - // Access the private _taskMessageQueues to verify no queue was created - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.size).toBe(0); + // Access the private _taskMessageQueue to verify no messages were queued + // Since we can't check if queues exist without messages, we verify that + // attempting to dequeue returns undefined (no messages queued) + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); // Clean up the pending request const requestId = (sendSpy.mock.calls[0][0] as JSONRPCResponse).id; @@ -3837,9 +3732,13 @@ describe('Message Interception', () => { ); // Verify the message was queued under the correct task ID - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has(taskId)).toBe(true); - expect(queues.has('wrong-task-id')).toBe(false); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Verify a message was queued for this task + const queuedMessage = await queue!.dequeue(taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.message.method).toBe('notifications/message'); }); it('should extract correct task ID from relatedTask metadata for requests', async () => { @@ -3863,13 +3762,13 @@ describe('Message Interception', () => { ); // Verify the message was queued under the correct task ID - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has(taskId)).toBe(true); - expect(queues.has('wrong-task-id')).toBe(false); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); // Clean up the pending request - const queue = queues.get(taskId)!; - const queuedMessage = queue!.dequeue(); + const queuedMessage = await queue!.dequeue(taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.message.method).toBe('test/request'); transport.onmessage?.({ jsonrpc: '2.0', id: (queuedMessage!.message as JSONRPCRequest).id, @@ -3887,75 +3786,79 @@ describe('Message Interception', () => { await protocol.notification({ method: 'test3', params: {} }, { relatedTask: { taskId: 'task-A' } }); // Verify messages are queued under correct task IDs - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has('task-A')).toBe(true); - expect(queues.has('task-B')).toBe(true); - - const queueA = queues.get('task-A')!; - const queueB = queues.get('task-B')!; + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); - expect(queueA.size()).toBe(2); // Two messages for task-A - expect(queueB.size()).toBe(1); // One message for task-B + // Verify two messages for task-A + const msg1A = await queue!.dequeue('task-A'); + const msg2A = await queue!.dequeue('task-A'); + const msg3A = await queue!.dequeue('task-A'); // Should be undefined + expect(msg1A).toBeDefined(); + expect(msg2A).toBeDefined(); + expect(msg3A).toBeUndefined(); + + // Verify one message for task-B + const msg1B = await queue!.dequeue('task-B'); + const msg2B = await queue!.dequeue('task-B'); // Should be undefined + expect(msg1B).toBeDefined(); + expect(msg2B).toBeUndefined(); }); }); describe('queue creation on first message', () => { - it('should create queue on first message for a task', async () => { + it('should queue messages for a task', async () => { await protocol.connect(transport); - // Verify no queues exist initially - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.size).toBe(0); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); // Send first message for a task await protocol.notification({ method: 'test', params: {} }, { relatedTask: { taskId: 'new-task' } }); - // Verify queue was created - expect(queues.has('new-task')).toBe(true); - expect(queues.size).toBe(1); + // Verify message was queued + const msg = await queue!.dequeue('new-task'); + expect(msg).toBeDefined(); + expect(msg?.message.method).toBe('test'); }); - it('should reuse existing queue for subsequent messages', async () => { + it('should queue multiple messages for the same task', async () => { await protocol.connect(transport); - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); // Send first message await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'reuse-task' } }); - const firstQueue = queues.get('reuse-task'); - expect(firstQueue).toBeDefined(); - expect(firstQueue!.size()).toBe(1); - // Send second message await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'reuse-task' } }); - const secondQueue = queues.get('reuse-task'); - - // Should be the same queue instance - expect(secondQueue).toBe(firstQueue); - expect(secondQueue!.size()).toBe(2); + // Verify both messages were queued in order + const msg1 = await queue!.dequeue('reuse-task'); + const msg2 = await queue!.dequeue('reuse-task'); + expect(msg1).toBeDefined(); + expect(msg1?.message.method).toBe('test1'); + expect(msg2).toBeDefined(); + expect(msg2?.message.method).toBe('test2'); }); - it('should create separate queues for different tasks', async () => { + it('should queue messages for different tasks separately', async () => { await protocol.connect(transport); - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); // Send messages for different tasks await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'task-1' } }); await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'task-2' } }); - // Verify separate queues were created - expect(queues.size).toBe(2); - expect(queues.has('task-1')).toBe(true); - expect(queues.has('task-2')).toBe(true); - - const queue1 = queues.get('task-1')!; - const queue2 = queues.get('task-2')!; - - // Verify they are different queue instances - expect(queue1).not.toBe(queue2); + // Verify messages are queued separately + const msg1 = await queue!.dequeue('task-1'); + const msg2 = await queue!.dequeue('task-2'); + expect(msg1).toBeDefined(); + expect(msg1?.message.method).toBe('test1'); + expect(msg2).toBeDefined(); + expect(msg2?.message.method).toBe('test2'); }); }); @@ -3973,11 +3876,11 @@ describe('Message Interception', () => { { relatedTask } ); - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - const queue = queues.get('task-meta-123')!; - const queuedMessage = queue!.dequeue(); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue('task-meta-123'); // Verify the metadata is preserved in the queued message + expect(queuedMessage).toBeDefined(); expect(queuedMessage!.message.params!._meta).toBeDefined(); expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); }); @@ -3997,11 +3900,11 @@ describe('Message Interception', () => { { relatedTask } ); - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - const queue = queues.get('task-meta-456')!; - const queuedMessage = queue!.dequeue(); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue('task-meta-456'); // Verify the metadata is preserved in the queued message + expect(queuedMessage).toBeDefined(); expect(queuedMessage!.message.params!._meta).toBeDefined(); expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); @@ -4033,11 +3936,11 @@ describe('Message Interception', () => { } ); - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - const queue = queues.get('task-preserve-meta')!; - const queuedMessage = queue!.dequeue(); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue('task-preserve-meta'); // Verify both existing and new metadata are preserved + expect(queuedMessage).toBeDefined(); expect(queuedMessage!.message.params!._meta!.customField).toBe('customValue'); expect(queuedMessage!.message.params!._meta!.anotherField).toBe(123); expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ @@ -4061,7 +3964,7 @@ describe('Queue lifecycle management', () => { protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); }); describe('queue cleanup on task completion', () => { @@ -4077,15 +3980,21 @@ describe('Queue lifecycle management', () => { await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); // Verify messages are queued - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has(taskId)).toBe(true); - expect(queues.get(taskId)!.size()).toBe(2); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Verify messages can be dequeued + const msg1 = await queue!.dequeue(taskId); + const msg2 = await queue!.dequeue(taskId); + expect(msg1).toBeDefined(); + expect(msg2).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); - // Verify queue is cleared - expect(queues.has(taskId)).toBe(false); + // After cleanup, no more messages should be available + const msg3 = await queue!.dequeue(taskId); + expect(msg3).toBeUndefined(); }); it('should clear queue after delivering messages on tasks/result for completed task', async () => { @@ -4116,9 +4025,10 @@ describe('Queue lifecycle management', () => { await resultPromise; - // Verify queue is cleared after delivery - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has(taskId)).toBe(false); + // Verify queue is cleared after delivery (no messages available) + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const msg = await queue!.dequeue(taskId); + expect(msg).toBeUndefined(); }); }); @@ -4134,9 +4044,12 @@ describe('Queue lifecycle management', () => { await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); // Verify message is queued - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has(taskId)).toBe(true); - expect(queues.get(taskId)!.size()).toBe(1); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const msg1 = await queue!.dequeue(taskId); + expect(msg1).toBeDefined(); + + // Re-queue the message for cancellation test + await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); // Mock task as non-terminal mockTaskStore.getTask.mockResolvedValue(task); @@ -4152,8 +4065,9 @@ describe('Queue lifecycle management', () => { // Wait for cancellation to process await new Promise(resolve => setTimeout(resolve, 50)); - // Verify queue is cleared - expect(queues.has(taskId)).toBe(false); + // Verify queue is cleared (no messages available) + const msg2 = await queue!.dequeue(taskId); + expect(msg2).toBeUndefined(); }); it('should reject pending request resolvers when task is cancelled', async () => { @@ -4171,9 +4085,8 @@ describe('Queue lifecycle management', () => { .catch(err => err); // Verify request is queued - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has(taskId)).toBe(true); - expect(queues.get(taskId)!.size()).toBe(1); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); // Mock task as non-terminal mockTaskStore.getTask.mockResolvedValue(task); @@ -4194,8 +4107,9 @@ describe('Queue lifecycle management', () => { expect(result).toBeInstanceOf(McpError); expect(result.message).toContain('Task cancelled or completed'); - // Verify queue is cleared - expect(queues.has(taskId)).toBe(false); + // Verify queue is cleared (no messages available) + const msg = await queue!.dequeue(taskId); + expect(msg).toBeUndefined(); }); }); @@ -4212,15 +4126,21 @@ describe('Queue lifecycle management', () => { await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); // Verify messages are queued - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has(taskId)).toBe(true); - expect(queues.get(taskId)!.size()).toBe(2); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Verify messages can be dequeued + const msg1 = await queue!.dequeue(taskId); + const msg2 = await queue!.dequeue(taskId); + expect(msg1).toBeDefined(); + expect(msg2).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); - // Verify queue is cleared - expect(queues.has(taskId)).toBe(false); + // After cleanup, no more messages should be available + const msg3 = await queue!.dequeue(taskId); + expect(msg3).toBeUndefined(); }); it('should reject pending request resolvers when task fails', async () => { @@ -4238,8 +4158,8 @@ describe('Queue lifecycle management', () => { .catch(err => err); // Verify request is queued - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has(taskId)).toBe(true); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); @@ -4249,8 +4169,9 @@ describe('Queue lifecycle management', () => { expect(result).toBeInstanceOf(McpError); expect(result.message).toContain('Task cancelled or completed'); - // Verify queue is cleared - expect(queues.has(taskId)).toBe(false); + // Verify queue is cleared (no messages available) + const msg = await queue!.dequeue(taskId); + expect(msg).toBeUndefined(); }); }); @@ -4282,9 +4203,8 @@ describe('Queue lifecycle management', () => { .catch(err => err); // Verify requests are queued - const queues = (protocol as unknown as TestProtocol)._taskMessageQueues as Map; - expect(queues.has(taskId)).toBe(true); - expect(queues.get(taskId)!.size()).toBe(3); + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); @@ -4301,8 +4221,9 @@ describe('Queue lifecycle management', () => { expect(result3).toBeInstanceOf(McpError); expect(result3.message).toContain('Task cancelled or completed'); - // Verify queue is cleared - expect(queues.has(taskId)).toBe(false); + // Verify queue is cleared (no messages available) + const msg = await queue!.dequeue(taskId); + expect(msg).toBeUndefined(); }); it('should clean up resolver mappings when rejecting requests', async () => { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 8e5c216e1..d24c6f395 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -45,81 +45,10 @@ import { } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; -import { isTerminal, TaskStore } from './task.js'; +import { isTerminal, TaskStore, TaskMessageQueue, QueuedMessage } from './task.js'; import { getMethodLiteral, parseWithCompat } from '../server/zod-json-schema-compat.js'; import { ResponseMessage } from './responseMessage.js'; -/** - * Represents a message queued for side-channel delivery via tasks/result. - */ -export interface QueuedMessage { - /** Type of message */ - type: 'request' | 'notification'; - /** The actual JSONRPC message */ - message: JSONRPCRequest | JSONRPCNotification; - /** When it was queued */ - timestamp: number; - /** For requests: resolver to call when response is received */ - responseResolver?: (response: JSONRPCResponse | Error) => void; - /** For requests: the original request ID for response routing */ - originalRequestId?: RequestId; -} - -/** - * A per-task FIFO queue for server-initiated messages that will be delivered - * through the tasks/result response stream. - */ -export class TaskMessageQueue { - private messages: QueuedMessage[] = []; - - /** - * Adds a message to the end of the queue. - * @param message The message to enqueue - */ - enqueue(message: QueuedMessage): void { - this.messages.push(message); - } - - /** - * Removes and returns the first message from the queue. - * @returns The first message, or undefined if the queue is empty - */ - dequeue(): QueuedMessage | undefined { - return this.messages.shift(); - } - - /** - * Removes and returns all messages from the queue. - * @returns Array of all messages that were in the queue - */ - dequeueAll(): QueuedMessage[] { - const allMessages = this.messages; - this.messages = []; - return allMessages; - } - - /** - * Removes all messages from the queue. - */ - clear(): void { - this.messages = []; - } - - /** - * Returns the number of messages in the queue. - */ - size(): number { - return this.messages.length; - } - - /** - * Checks if the queue is empty. - */ - isEmpty(): boolean { - return this.messages.length === 0; - } -} - /** * Callback for progress notifications. */ @@ -149,6 +78,11 @@ export type ProtocolOptions = { * and provides task storage capabilities to request handlers. */ taskStore?: TaskStore; + /** + * Optional task message queue implementation for managing server-initiated messages + * that will be delivered through the tasks/result response stream. + */ + taskMessageQueue?: TaskMessageQueue; /** * Default polling interval (in milliseconds) for task status checks when no pollInterval * is provided by the server. Defaults to 5000ms if not specified. @@ -384,9 +318,9 @@ export abstract class Protocol = new Map(); private _taskStore?: TaskStore; + private _taskMessageQueue?: TaskMessageQueue; - // Task message queues for side-channel delivery - private _taskMessageQueues: Map = new Map(); + // Task result waiters for side-channel delivery in tasks private _taskResultWaiters: Map void>> = new Map(); private _requestResolvers: Map void> = new Map(); @@ -431,6 +365,7 @@ export abstract class Protocol { const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); @@ -449,13 +384,11 @@ export abstract class Protocol { const handleTaskResult = async (): Promise => { const taskId = request.params.taskId; - const queue = this._taskMessageQueues.get(taskId); // Deliver queued messages - if (queue && !queue.isEmpty()) { - while (!queue.isEmpty()) { - const queuedMessage = queue.dequeue()!; - + if (this._taskMessageQueue) { + let queuedMessage: QueuedMessage | undefined; + while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, extra.sessionId))) { // Strip relatedTask metadata when dequeuing for delivery // The metadata was used for queuing, but shouldn't be sent to the client const messageToSend = { ...queuedMessage.message }; @@ -468,15 +401,14 @@ export abstract class Protocol | undefined; if (queuedMessage.type === 'request' && queuedMessage.responseResolver) { - // Wait for response before continuing to next message - await new Promise((resolve, reject) => { - const originalResolver = queuedMessage.responseResolver!; + // Capture in const to satisfy TypeScript's flow analysis + const msg = queuedMessage; + responsePromise = new Promise((resolve, reject) => { + const originalResolver = msg.responseResolver!; const wrappedResolver = (response: JSONRPCResponse | Error) => { // First, deliver the response to the task handler originalResolver(response); @@ -488,11 +420,20 @@ export abstract class Protocol { + // Notify any waiting tasks/result calls + this._notifyTaskResultWaiters(relatedTaskId); + }) + .catch(error => { + this._cleanupTimeout(messageId); + reject(error); }); - // Notify any waiting tasks/result calls - this._notifyTaskResultWaiters(relatedTaskId); - } catch (error) { - this._cleanupTimeout(messageId); - reject(error); - return; - } - // Don't send through transport - queued messages are delivered via tasks/result only // This prevents duplicate delivery for bidirectional transports } else { @@ -1275,7 +1215,7 @@ export abstract class Protocol { + // Task message queues are only used when taskStore is configured + if (!this._taskStore || !this._taskMessageQueue) { + throw new McpError(ErrorCode.InternalError, 'Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); } const maxQueueSize = this._options?.maxTaskQueueSize; - if (maxQueueSize !== undefined && queue.size() >= maxQueueSize) { - const errorMessage = `Task message queue overflow: queue size (${queue.size()}) exceeds maximum (${maxQueueSize})`; - - // Log the error for debugging - this._onerror(new Error(errorMessage)); - this._taskStore?.updateTaskStatus(taskId, 'failed', 'Task message queue overflow').catch(err => this._onerror(err)); - this._clearTaskQueue(taskId); - - throw new McpError(ErrorCode.InternalError, 'Task message queue overflow'); + try { + await this._taskMessageQueue.enqueue(taskId, message, sessionId, maxQueueSize); + } catch (error) { + // Enqueue failed (e.g., queue overflow, storage error) - fail the task and clear the queue + this._onerror(error as Error); + const errorMessage = error instanceof Error ? error.message : 'Task message enqueue failed'; + this._taskStore.updateTaskStatus(taskId, 'failed', errorMessage, sessionId).catch(err => this._onerror(err)); + await this._clearTaskQueue(taskId, sessionId); + throw new McpError(ErrorCode.InternalError, `Failed to enqueue task message: ${errorMessage}`); } - - queue.enqueue(message); } /** * Clears the message queue for a task and rejects any pending request resolvers. * @param taskId The task ID whose queue should be cleared + * @param sessionId Optional session ID for binding the operation to a specific session */ - private _clearTaskQueue(taskId: string): void { - const queue = this._taskMessageQueues.get(taskId); - if (queue) { + private async _clearTaskQueue(taskId: string, sessionId?: string): Promise { + if (this._taskMessageQueue) { // Reject any pending request resolvers - for (const message of queue.dequeueAll()) { + const messages = await this._taskMessageQueue.dequeueAll(taskId, sessionId); + for (const message of messages) { if (message.type === 'request' && message.responseResolver && message.originalRequestId !== undefined) { message.responseResolver(new McpError(ErrorCode.InternalError, 'Task cancelled or completed')); // Clean up the resolver mapping this._requestResolvers.delete(message.originalRequestId); } } - this._taskMessageQueues.delete(taskId); } } @@ -1544,13 +1482,10 @@ export abstract class Protocol { try { - const task = await this._taskStore?.getTask(taskId); - if (task && (isTerminal(task.status) || this._taskMessageQueues.get(taskId)?.size())) { - clearInterval(pollInterval); - this._notifyTaskResultWaiters(taskId); - } + this._notifyTaskResultWaiters(taskId); } catch { // Ignore errors during polling } diff --git a/src/shared/task-listing.test.ts b/src/shared/task-listing.test.ts index 9651df23c..975706070 100644 --- a/src/shared/task-listing.test.ts +++ b/src/shared/task-listing.test.ts @@ -2,7 +2,7 @@ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; import { InMemoryTransport } from '../inMemory.js'; import { Client } from '../client/index.js'; import { Server } from '../server/index.js'; -import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; describe('Task Listing with Pagination', () => { let client: Client; @@ -51,7 +51,8 @@ describe('Task Listing with Pagination', () => { } } }, - taskStore + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() } ); diff --git a/src/shared/task.ts b/src/shared/task.ts index c7946006a..9c2557f55 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -1,4 +1,68 @@ -import { Task, TaskCreationParams, Request, RequestId, Result } from '../types.js'; +import { Task, TaskCreationParams, Request, RequestId, Result, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse } from '../types.js'; + +/** + * Represents a message queued for side-channel delivery via tasks/result. + */ +export interface QueuedMessage { + /** Type of message */ + type: 'request' | 'notification'; + /** The actual JSONRPC message */ + message: JSONRPCRequest | JSONRPCNotification; + /** When it was queued */ + timestamp: number; + /** For requests: resolver to call when response is received */ + responseResolver?: (response: JSONRPCResponse | Error) => void; + /** For requests: the original request ID for response routing */ + originalRequestId?: RequestId; +} + +/** + * Interface for managing per-task FIFO message queues. + * + * Similar to TaskStore, this allows pluggable queue implementations + * (in-memory, Redis, other distributed queues, etc.) for server-initiated + * messages that will be delivered through the tasks/result response stream. + * + * Each method accepts taskId and optional sessionId parameters to enable + * a single queue instance to manage messages for multiple tasks, with + * isolation based on task ID and session ID. + * + * All methods are async to support external storage implementations. + * + * Performance Notes: + * - enqueue() atomically enforces maxSize to prevent race conditions + * - dequeue() returns undefined when empty, eliminating need for isEmpty() checks + * - dequeueAll() is used when tasks are cancelled/failed to reject pending resolvers + */ +export interface TaskMessageQueue { + /** + * Adds a message to the end of the queue for a specific task. + * Atomically checks queue size and throws if maxSize would be exceeded. + * @param taskId The task identifier + * @param message The message to enqueue + * @param sessionId Optional session ID for binding the operation to a specific session + * @param maxSize Optional maximum queue size - if specified and queue is full, throws an error + * @throws Error if maxSize is specified and would be exceeded + */ + enqueue(taskId: string, message: QueuedMessage, sessionId?: string, maxSize?: number): Promise; + + /** + * Removes and returns the first message from the queue for a specific task. + * @param taskId The task identifier + * @param sessionId Optional session ID for binding the query to a specific session + * @returns The first message, or undefined if the queue is empty + */ + dequeue(taskId: string, sessionId?: string): Promise; + + /** + * Removes and returns all messages from the queue for a specific task. + * Used when tasks are cancelled or failed to reject any pending request resolvers. + * @param taskId The task identifier + * @param sessionId Optional session ID for binding the query to a specific session + * @returns Array of all messages that were in the queue + */ + dequeueAll(taskId: string, sessionId?: string): Promise; +} /** * Interface for storing and retrieving task state and results. From 86cc5cc057199f67a208fd5528d36fd0748fdf4e Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 21 Nov 2025 13:10:51 -0800 Subject: [PATCH 60/84] Refactor TaskMessageQueue to not store response closures --- src/examples/shared/inMemoryTaskStore.test.ts | 272 +++++- src/shared/protocol.test.ts | 882 ++++++++++++++++-- src/shared/protocol.ts | 85 +- src/shared/task.ts | 64 +- 4 files changed, 1175 insertions(+), 128 deletions(-) diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index ea74a1460..cba00d987 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -1,6 +1,7 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; -import { InMemoryTaskStore } from './inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from './inMemoryTaskStore.js'; import { TaskCreationParams, Request } from '../../types.js'; +import { QueuedMessage } from '../../shared/task.js'; describe('InMemoryTaskStore', () => { let store: InMemoryTaskStore; @@ -664,3 +665,272 @@ describe('InMemoryTaskStore', () => { }); }); }); + +describe('InMemoryTaskMessageQueue', () => { + let queue: InMemoryTaskMessageQueue; + + beforeEach(() => { + queue = new InMemoryTaskMessageQueue(); + }); + + describe('enqueue and dequeue', () => { + it('should enqueue and dequeue request messages', async () => { + const requestMessage: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-1', requestMessage); + const dequeued = await queue.dequeue('task-1'); + + expect(dequeued).toEqual(requestMessage); + }); + + it('should enqueue and dequeue notification messages', async () => { + const notificationMessage: QueuedMessage = { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: { progress: 50, total: 100 } + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-2', notificationMessage); + const dequeued = await queue.dequeue('task-2'); + + expect(dequeued).toEqual(notificationMessage); + }); + + it('should enqueue and dequeue response messages', async () => { + const responseMessage: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 42, + result: { content: [{ type: 'text', text: 'Success' }] } + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-3', responseMessage); + const dequeued = await queue.dequeue('task-3'); + + expect(dequeued).toEqual(responseMessage); + }); + + it('should return undefined when dequeuing from empty queue', async () => { + const dequeued = await queue.dequeue('task-empty'); + expect(dequeued).toBeUndefined(); + }); + + it('should maintain FIFO order for mixed message types', async () => { + const request: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: {} + }, + timestamp: 1000 + }; + + const notification: QueuedMessage = { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: {} + }, + timestamp: 2000 + }; + + const response: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 1, + result: {} + }, + timestamp: 3000 + }; + + await queue.enqueue('task-fifo', request); + await queue.enqueue('task-fifo', notification); + await queue.enqueue('task-fifo', response); + + expect(await queue.dequeue('task-fifo')).toEqual(request); + expect(await queue.dequeue('task-fifo')).toEqual(notification); + expect(await queue.dequeue('task-fifo')).toEqual(response); + expect(await queue.dequeue('task-fifo')).toBeUndefined(); + }); + }); + + describe('dequeueAll', () => { + it('should dequeue all messages including responses', async () => { + const request: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: {} + }, + timestamp: 1000 + }; + + const response: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 1, + result: {} + }, + timestamp: 2000 + }; + + const notification: QueuedMessage = { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: {} + }, + timestamp: 3000 + }; + + await queue.enqueue('task-all', request); + await queue.enqueue('task-all', response); + await queue.enqueue('task-all', notification); + + const all = await queue.dequeueAll('task-all'); + + expect(all).toHaveLength(3); + expect(all[0]).toEqual(request); + expect(all[1]).toEqual(response); + expect(all[2]).toEqual(notification); + }); + + it('should return empty array for non-existent task', async () => { + const all = await queue.dequeueAll('non-existent'); + expect(all).toEqual([]); + }); + + it('should clear the queue after dequeueAll', async () => { + const message: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'test', + params: {} + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-clear', message); + await queue.dequeueAll('task-clear'); + + const dequeued = await queue.dequeue('task-clear'); + expect(dequeued).toBeUndefined(); + }); + }); + + describe('queue size limits', () => { + it('should throw when maxSize is exceeded', async () => { + const message: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'test', + params: {} + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-limit', message, undefined, 2); + await queue.enqueue('task-limit', message, undefined, 2); + + await expect(queue.enqueue('task-limit', message, undefined, 2)).rejects.toThrow('Task message queue overflow'); + }); + + it('should allow enqueue when under maxSize', async () => { + const message: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 1, + result: {} + }, + timestamp: Date.now() + }; + + await expect(queue.enqueue('task-ok', message, undefined, 5)).resolves.toBeUndefined(); + }); + }); + + describe('task isolation', () => { + it('should isolate messages between different tasks', async () => { + const message1: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'test1', + params: {} + }, + timestamp: 1000 + }; + + const message2: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 2, + result: {} + }, + timestamp: 2000 + }; + + await queue.enqueue('task-a', message1); + await queue.enqueue('task-b', message2); + + expect(await queue.dequeue('task-a')).toEqual(message1); + expect(await queue.dequeue('task-b')).toEqual(message2); + expect(await queue.dequeue('task-a')).toBeUndefined(); + expect(await queue.dequeue('task-b')).toBeUndefined(); + }); + }); + + describe('response message error handling', () => { + it('should handle response messages with errors', async () => { + const errorResponse: QueuedMessage = { + type: 'error', + message: { + jsonrpc: '2.0', + id: 1, + error: { + code: -32600, + message: 'Invalid Request' + } + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-error', errorResponse); + const dequeued = await queue.dequeue('task-error'); + + expect(dequeued).toEqual(errorResponse); + expect(dequeued?.type).toBe('error'); + }); + }); +}); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 1f58eb861..042baea7e 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -16,7 +16,7 @@ import { } from '../types.js'; import { Protocol, mergeCapabilities } from './protocol.js'; import { Transport, TransportSendOptions } from './transport.js'; -import { TaskStore, TaskMessageQueue } from './task.js'; +import { TaskStore, TaskMessageQueue, QueuedMessage, QueuedNotification, QueuedRequest } from './task.js'; import { MockInstance, vi } from 'vitest'; import { JSONRPCResponse, JSONRPCRequest, JSONRPCError } from '../types.js'; import { ErrorMessage, ResponseMessage, toArrayAsync } from './responseMessage.js'; @@ -123,6 +123,16 @@ function assertErrorResponse(o: ResponseMessage): asserts o is ErrorMess expect(o.type).toBe('error'); } +function assertQueuedNotification(o?: QueuedMessage): asserts o is QueuedNotification { + expect(o).toBeDefined(); + expect(o?.type).toBe('notification'); +} + +function assertQueuedRequest(o?: QueuedMessage): asserts o is QueuedRequest { + expect(o).toBeDefined(); + expect(o?.type).toBe('request'); +} + describe('protocol tests', () => { let protocol: Protocol; let transport: MockTransport; @@ -1157,10 +1167,9 @@ describe('Task-based execution', () => { expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('parent-task-456'); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.type).toBe('notification'); - expect(queuedMessage?.message.method).toBe('notifications/message'); - expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: 'parent-task-456' }); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.method).toBe('notifications/message'); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: 'parent-task-456' }); }); }); @@ -1205,9 +1214,8 @@ describe('Task-based execution', () => { expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('parent-task'); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.type).toBe('request'); - expect(queuedMessage?.message.params).toMatchObject({ + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.params).toMatchObject({ name: 'test-tool', task: { ttl: 60000, @@ -1991,10 +1999,9 @@ describe('Task-based execution', () => { expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('parent-task-123'); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.type).toBe('notification'); - expect(queuedMessage?.message.method).toBe('notifications/message'); - expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.method).toBe('notifications/message'); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: 'parent-task-123' }); @@ -3038,10 +3045,9 @@ describe('Message interception for task-related notifications', () => { expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(task.taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.type).toBe('notification'); - expect(queuedMessage?.message.method).toBe('notifications/message'); - expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.method).toBe('notifications/message'); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); }); it('should not queue notifications without related-task metadata', async () => { @@ -3222,9 +3228,9 @@ describe('Message interception for task-related notifications', () => { expect(queue).toBeDefined(); for (let i = 0; i < 5; i++) { - const message = await queue!.dequeue(task.taskId); - expect(message).toBeDefined(); - expect(message!.message.params!.data).toBe(`message ${i}`); + const queuedMessage = await queue!.dequeue(task.taskId); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.params!.data).toBe(`message ${i}`); } }); }); @@ -3263,17 +3269,19 @@ describe('Message interception for task-related requests', () => { expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(task.taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.type).toBe('request'); - expect(queuedMessage?.message.method).toBe('ping'); - expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); - expect(queuedMessage?.responseResolver).toBeDefined(); - expect(queuedMessage!.originalRequestId!).toBeDefined(); + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.method).toBe('ping'); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); + + // Verify resolver is stored in _requestResolvers map (not in the message) + const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; + const resolvers = (server as unknown as TestProtocol)._requestResolvers; + expect(resolvers.has(requestId)).toBe(true); // Clean up - send a response to prevent hanging promise transport.onmessage?.({ jsonrpc: '2.0', - id: queuedMessage!.originalRequestId!, + id: requestId, result: {} }); @@ -3363,9 +3371,10 @@ describe('Message interception for task-related requests', () => { // Clean up const queue = (server as unknown as TestProtocol)._taskMessageQueue; const queuedMessage = await queue!.dequeue(task.taskId); + const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; transport.onmessage?.({ jsonrpc: '2.0', - id: queuedMessage!.originalRequestId!, + id: requestId, result: {} }); @@ -3407,7 +3416,7 @@ describe('Message interception for task-related requests', () => { // Get the request ID from the queue const queue = (server as unknown as TestProtocol)._taskMessageQueue; const queuedMessage = await queue!.dequeue(task.taskId); - const requestId = queuedMessage!.originalRequestId!; + const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; expect(resolvers.has(requestId)).toBe(true); @@ -3427,13 +3436,14 @@ describe('Message interception for task-related requests', () => { it('should route responses to side-channeled requests', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); + const queue = new InMemoryTaskMessageQueue(); const server = new (class extends Protocol { protected assertCapabilityForMethod(_method: string): void {} protected assertNotificationCapability(_method: string): void {} protected assertRequestHandlerCapability(_method: string): void {} protected assertTaskCapability(_method: string): void {} protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ taskStore, taskMessageQueue: queue }); await server.connect(transport); @@ -3453,17 +3463,37 @@ describe('Message interception for task-related requests', () => { ); // Get the request ID from the queue - const queue = (server as unknown as TestProtocol)._taskMessageQueue; - const queuedMessage = await queue!.dequeue(task.taskId); - const requestId = queuedMessage!.originalRequestId!; + const queuedMessage = await queue.dequeue(task.taskId); + const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; + + // Enqueue a response message to the queue (simulating client sending response back) + await queue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: requestId, + result: { message: 'pong' } + }, + timestamp: Date.now() + }); - // Send a response + // Simulate a client calling tasks/result which will process the response + // This is done by creating a mock request handler that will trigger the GetTaskPayloadRequest handler + const mockRequestId = 999; transport.onmessage?.({ jsonrpc: '2.0', - id: requestId, - result: { message: 'pong' } + id: mockRequestId, + method: 'tasks/result', + params: { taskId: task.taskId } }); + // Wait for the response to be processed + await new Promise(resolve => setTimeout(resolve, 50)); + + // Mark task as completed + await taskStore.updateTaskStatus(task.taskId, 'completed'); + await taskStore.storeTaskResult(task.taskId, 'completed', { _meta: {} }); + // Verify the response was routed correctly const result = await requestPromise; expect(result).toEqual({ message: 'pong' }); @@ -3505,24 +3535,44 @@ describe('Message interception for task-related requests', () => { // Get the request ID from the queue const queue = (server as unknown as TestProtocol)._taskMessageQueue; const queuedMessage = await queue!.dequeue(task.taskId); - const requestId = queuedMessage!.originalRequestId!; + const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; - // Manually delete the response handler to simulate missing resolver - (server as unknown as TestProtocol)._responseHandlers.delete(requestId); + // Manually delete the resolver to simulate missing resolver + (server as unknown as TestProtocol)._requestResolvers.delete(requestId); - // Send a response - this should trigger the error logging + // Enqueue a response message - this should trigger the error logging when processed + await queue!.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: requestId, + result: { message: 'pong' } + }, + timestamp: Date.now() + }); + + // Simulate a client calling tasks/result which will process the response + const mockRequestId = 888; transport.onmessage?.({ jsonrpc: '2.0', - id: requestId, - result: { message: 'pong' } + id: mockRequestId, + method: 'tasks/result', + params: { taskId: task.taskId } }); - // Wait a bit for the error to be logged - await new Promise(resolve => setTimeout(resolve, 10)); + // Wait for the response to be processed + await new Promise(resolve => setTimeout(resolve, 50)); + + // Mark task as completed + await taskStore.updateTaskStatus(task.taskId, 'completed'); + await taskStore.storeTaskResult(task.taskId, 'completed', { _meta: {} }); + + // Wait a bit more for error to be logged + await new Promise(resolve => setTimeout(resolve, 50)); // Verify error was logged - expect(errors.length).toBe(1); - expect(errors[0].message).toContain('Response handler missing for side-channeled request'); + expect(errors.length).toBeGreaterThanOrEqual(1); + expect(errors.some(e => e.message.includes('Response handler missing for request'))).toBe(true); }); it('should handle queue overflow for requests', async () => { @@ -3619,8 +3669,7 @@ describe('Message Interception', () => { expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('task-123'); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage!.type).toBe('notification'); + assertQueuedNotification(queuedMessage); expect(queuedMessage!.message.method).toBe('notifications/message'); }); @@ -3648,15 +3697,18 @@ describe('Message Interception', () => { expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('task-456'); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage!.type).toBe('request'); - expect(queuedMessage!.message.method).toBe('test/request'); - expect(queuedMessage!.responseResolver).toBeDefined(); + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.method).toBe('test/request'); + + // Verify resolver is stored in _requestResolvers map (not in the message) + const requestId = queuedMessage.message.id as RequestId; + const resolvers = (protocol as unknown as TestProtocol)._requestResolvers; + expect(resolvers.has(requestId)).toBe(true); // Clean up the pending request transport.onmessage?.({ jsonrpc: '2.0', - id: (queuedMessage!.message as JSONRPCRequest).id, + id: requestId, result: { result: 'success' } }); await requestPromise; @@ -3737,8 +3789,8 @@ describe('Message Interception', () => { // Verify a message was queued for this task const queuedMessage = await queue!.dequeue(taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.message.method).toBe('notifications/message'); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.method).toBe('notifications/message'); }); it('should extract correct task ID from relatedTask metadata for requests', async () => { @@ -3767,11 +3819,11 @@ describe('Message Interception', () => { // Clean up the pending request const queuedMessage = await queue!.dequeue(taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.message.method).toBe('test/request'); + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.method).toBe('test/request'); transport.onmessage?.({ jsonrpc: '2.0', - id: (queuedMessage!.message as JSONRPCRequest).id, + id: queuedMessage.message.id, result: { result: 'success' } }); await requestPromise; @@ -3817,8 +3869,8 @@ describe('Message Interception', () => { // Verify message was queued const msg = await queue!.dequeue('new-task'); - expect(msg).toBeDefined(); - expect(msg?.message.method).toBe('test'); + assertQueuedNotification(msg); + expect(msg.message.method).toBe('test'); }); it('should queue multiple messages for the same task', async () => { @@ -3836,10 +3888,10 @@ describe('Message Interception', () => { // Verify both messages were queued in order const msg1 = await queue!.dequeue('reuse-task'); const msg2 = await queue!.dequeue('reuse-task'); - expect(msg1).toBeDefined(); - expect(msg1?.message.method).toBe('test1'); - expect(msg2).toBeDefined(); - expect(msg2?.message.method).toBe('test2'); + assertQueuedNotification(msg1); + expect(msg1.message.method).toBe('test1'); + assertQueuedNotification(msg2); + expect(msg2.message.method).toBe('test2'); }); it('should queue messages for different tasks separately', async () => { @@ -3855,9 +3907,9 @@ describe('Message Interception', () => { // Verify messages are queued separately const msg1 = await queue!.dequeue('task-1'); const msg2 = await queue!.dequeue('task-2'); - expect(msg1).toBeDefined(); + assertQueuedNotification(msg1); expect(msg1?.message.method).toBe('test1'); - expect(msg2).toBeDefined(); + assertQueuedNotification(msg2); expect(msg2?.message.method).toBe('test2'); }); }); @@ -3881,8 +3933,9 @@ describe('Message Interception', () => { // Verify the metadata is preserved in the queued message expect(queuedMessage).toBeDefined(); - expect(queuedMessage!.message.params!._meta).toBeDefined(); - expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.params!._meta).toBeDefined(); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); }); it('should preserve relatedTask metadata in queued request', async () => { @@ -3905,8 +3958,9 @@ describe('Message Interception', () => { // Verify the metadata is preserved in the queued message expect(queuedMessage).toBeDefined(); - expect(queuedMessage!.message.params!._meta).toBeDefined(); - expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.params!._meta).toBeDefined(); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); // Clean up transport.onmessage?.({ @@ -3941,9 +3995,10 @@ describe('Message Interception', () => { // Verify both existing and new metadata are preserved expect(queuedMessage).toBeDefined(); - expect(queuedMessage!.message.params!._meta!.customField).toBe('customValue'); - expect(queuedMessage!.message.params!._meta!.anotherField).toBe(123); - expect(queuedMessage!.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.params!._meta!.customField).toBe('customValue'); + expect(queuedMessage.message.params!._meta!.anotherField).toBe(123); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: 'task-preserve-meta' }); }); @@ -4662,3 +4717,686 @@ describe('requestStream() method', () => { }); }); }); + +describe('Error handling for missing resolvers', () => { + let protocol: Protocol; + let transport: MockTransport; + let taskStore: TaskStore & { [K in keyof TaskStore]: MockInstance }; + let taskMessageQueue: TaskMessageQueue; + let errorHandler: MockInstance; + + beforeEach(() => { + taskStore = createMockTaskStore(); + taskMessageQueue = new InMemoryTaskMessageQueue(); + errorHandler = vi.fn(); + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ + taskStore, + taskMessageQueue, + defaultTaskPollInterval: 100 + }); + + // @ts-expect-error deliberately overriding error handler with mock + protocol.onerror = errorHandler; + transport = new MockTransport(); + }); + + describe('Response routing with missing resolvers', () => { + it('should log error for unknown request ID without throwing', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + // Enqueue a response message without a corresponding resolver + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 999, // Non-existent request ID + result: { content: [] } + }, + timestamp: Date.now() + }); + + // Set up the GetTaskPayloadRequest handler to process the message + const testProtocol = protocol as unknown as TestProtocol; + + // Simulate dequeuing and processing the response + const queuedMessage = await taskMessageQueue.dequeue(task.taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('response'); + + // Manually trigger the response handling logic + if (queuedMessage && queuedMessage.type === 'response') { + const responseMessage = queuedMessage.message as JSONRPCResponse; + const requestId = responseMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + + if (!resolver) { + // This simulates what happens in the actual handler + protocol.onerror?.(new Error(`Response handler missing for request ${requestId}`)); + } + } + + // Verify error was logged + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Response handler missing for request 999') + }) + ); + }); + + it('should continue processing after missing resolver error', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + // Enqueue a response with missing resolver, then a valid notification + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 999, + result: { content: [] } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: { progress: 50, total: 100 } + }, + timestamp: Date.now() + }); + + // Process first message (response with missing resolver) + const msg1 = await taskMessageQueue.dequeue(task.taskId); + expect(msg1?.type).toBe('response'); + + // Process second message (should work fine) + const msg2 = await taskMessageQueue.dequeue(task.taskId); + expect(msg2?.type).toBe('notification'); + expect(msg2?.message).toMatchObject({ + method: 'notifications/progress' + }); + }); + }); + + describe('Task cancellation with missing resolvers', () => { + it('should log error when resolver is missing during cleanup', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + // Enqueue a request without storing a resolver + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: 42, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + // Clear the task queue (simulating cancellation) + const testProtocol = protocol as unknown as TestProtocol; + await testProtocol._clearTaskQueue(task.taskId); + + // Verify error was logged for missing resolver + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Resolver missing for request 42') + }) + ); + }); + + it('should handle cleanup gracefully when resolver exists', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + const requestId = 42; + const resolverMock = vi.fn(); + + // Store a resolver + const testProtocol = protocol as unknown as TestProtocol; + testProtocol._requestResolvers.set(requestId, resolverMock); + + // Enqueue a request + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: requestId, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + // Clear the task queue + await testProtocol._clearTaskQueue(task.taskId); + + // Verify resolver was called with cancellation error + expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + + // Verify the error has the correct properties + const calledError = resolverMock.mock.calls[0][0]; + expect(calledError.code).toBe(ErrorCode.InternalError); + expect(calledError.message).toContain('Task cancelled or completed'); + + // Verify resolver was removed + expect(testProtocol._requestResolvers.has(requestId)).toBe(false); + }); + + it('should handle mixed messages during cleanup', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + const testProtocol = protocol as unknown as TestProtocol; + + // Enqueue multiple messages: request with resolver, request without, notification + const requestId1 = 42; + const resolverMock = vi.fn(); + testProtocol._requestResolvers.set(requestId1, resolverMock); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: requestId1, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: 43, // No resolver for this one + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: { progress: 50, total: 100 } + }, + timestamp: Date.now() + }); + + // Clear the task queue + await testProtocol._clearTaskQueue(task.taskId); + + // Verify resolver was called for first request + expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + + // Verify the error has the correct properties + const calledError = resolverMock.mock.calls[0][0]; + expect(calledError.code).toBe(ErrorCode.InternalError); + expect(calledError.message).toContain('Task cancelled or completed'); + + // Verify error was logged for second request + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Resolver missing for request 43') + }) + ); + + // Verify queue is empty + const remaining = await taskMessageQueue.dequeue(task.taskId); + expect(remaining).toBeUndefined(); + }); + }); + + describe('Side-channeled request error handling', () => { + it('should log error when response handler is missing for side-channeled request', async () => { + await protocol.connect(transport); + + const testProtocol = protocol as unknown as TestProtocol; + const messageId = 123; + + // Create a response resolver without a corresponding response handler + const responseResolver = (response: JSONRPCResponse | Error) => { + const handler = testProtocol._responseHandlers.get(messageId); + if (handler) { + handler(response); + } else { + protocol.onerror?.(new Error(`Response handler missing for side-channeled request ${messageId}`)); + } + }; + + // Simulate the resolver being called without a handler + const mockResponse: JSONRPCResponse = { + jsonrpc: '2.0', + id: messageId, + result: { content: [] } + }; + + responseResolver(mockResponse); + + // Verify error was logged + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Response handler missing for side-channeled request 123') + }) + ); + }); + }); + + describe('Error handling does not throw exceptions', () => { + it('should not throw when processing response with missing resolver', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 999, + result: { content: [] } + }, + timestamp: Date.now() + }); + + // This should not throw + const processMessage = async () => { + const msg = await taskMessageQueue.dequeue(task.taskId); + if (msg && msg.type === 'response') { + const testProtocol = protocol as unknown as TestProtocol; + const responseMessage = msg.message as JSONRPCResponse; + const requestId = responseMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (!resolver) { + protocol.onerror?.(new Error(`Response handler missing for request ${requestId}`)); + } + } + }; + + await expect(processMessage()).resolves.not.toThrow(); + }); + + it('should not throw during task cleanup with missing resolvers', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: 42, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + const testProtocol = protocol as unknown as TestProtocol; + + // This should not throw + await expect(testProtocol._clearTaskQueue(task.taskId)).resolves.not.toThrow(); + }); + }); + + describe('Error message routing', () => { + it('should route error messages to resolvers correctly', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const requestId = 42; + const resolverMock = vi.fn(); + + // Store a resolver + const testProtocol = protocol as unknown as TestProtocol; + testProtocol._requestResolvers.set(requestId, resolverMock); + + // Enqueue an error message + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: requestId, + error: { + code: ErrorCode.InvalidRequest, + message: 'Invalid request parameters' + } + }, + timestamp: Date.now() + }); + + // Simulate dequeuing and processing the error + const queuedMessage = await taskMessageQueue.dequeue(task.taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('error'); + + // Manually trigger the error handling logic + if (queuedMessage && queuedMessage.type === 'error') { + const errorMessage = queuedMessage.message as JSONRPCError; + const reqId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(reqId); + + if (resolver) { + testProtocol._requestResolvers.delete(reqId); + const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + + // Verify resolver was called with McpError + expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + const calledError = resolverMock.mock.calls[0][0]; + expect(calledError.code).toBe(ErrorCode.InvalidRequest); + expect(calledError.message).toContain('Invalid request parameters'); + + // Verify resolver was removed from map + expect(testProtocol._requestResolvers.has(requestId)).toBe(false); + }); + + it('should log error for unknown request ID in error messages', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + // Enqueue an error message without a corresponding resolver + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: 999, + error: { + code: ErrorCode.InternalError, + message: 'Something went wrong' + } + }, + timestamp: Date.now() + }); + + // Simulate dequeuing and processing the error + const queuedMessage = await taskMessageQueue.dequeue(task.taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('error'); + + // Manually trigger the error handling logic + if (queuedMessage && queuedMessage.type === 'error') { + const testProtocol = protocol as unknown as TestProtocol; + const errorMessage = queuedMessage.message as JSONRPCError; + const requestId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + + if (!resolver) { + protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); + } + } + + // Verify error was logged + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Error handler missing for request 999') + }) + ); + }); + + it('should handle error messages with data field', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const requestId = 42; + const resolverMock = vi.fn(); + + // Store a resolver + const testProtocol = protocol as unknown as TestProtocol; + testProtocol._requestResolvers.set(requestId, resolverMock); + + // Enqueue an error message with data field + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: requestId, + error: { + code: ErrorCode.InvalidParams, + message: 'Validation failed', + data: { field: 'userName', reason: 'required' } + } + }, + timestamp: Date.now() + }); + + // Simulate dequeuing and processing the error + const queuedMessage = await taskMessageQueue.dequeue(task.taskId); + + if (queuedMessage && queuedMessage.type === 'error') { + const errorMessage = queuedMessage.message as JSONRPCError; + const reqId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(reqId); + + if (resolver) { + testProtocol._requestResolvers.delete(reqId); + const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + + // Verify resolver was called with McpError including data + expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + const calledError = resolverMock.mock.calls[0][0]; + expect(calledError.code).toBe(ErrorCode.InvalidParams); + expect(calledError.message).toContain('Validation failed'); + expect(calledError.data).toEqual({ field: 'userName', reason: 'required' }); + }); + + it('should not throw when processing error with missing resolver', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: 999, + error: { + code: ErrorCode.InternalError, + message: 'Error occurred' + } + }, + timestamp: Date.now() + }); + + // This should not throw + const processMessage = async () => { + const msg = await taskMessageQueue.dequeue(task.taskId); + if (msg && msg.type === 'error') { + const testProtocol = protocol as unknown as TestProtocol; + const errorMessage = msg.message as JSONRPCError; + const requestId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (!resolver) { + protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); + } + } + }; + + await expect(processMessage()).resolves.not.toThrow(); + }); + }); + + describe('Response and error message routing integration', () => { + it('should handle mixed response and error messages in queue', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const testProtocol = protocol as unknown as TestProtocol; + + // Set up resolvers for multiple requests + const resolver1 = vi.fn(); + const resolver2 = vi.fn(); + const resolver3 = vi.fn(); + + testProtocol._requestResolvers.set(1, resolver1); + testProtocol._requestResolvers.set(2, resolver2); + testProtocol._requestResolvers.set(3, resolver3); + + // Enqueue mixed messages: response, error, response + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 1, + result: { content: [{ type: 'text', text: 'Success' }] } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: 2, + error: { + code: ErrorCode.InvalidRequest, + message: 'Request failed' + } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 3, + result: { content: [{ type: 'text', text: 'Another success' }] } + }, + timestamp: Date.now() + }); + + // Process all messages + let msg; + while ((msg = await taskMessageQueue.dequeue(task.taskId))) { + if (msg.type === 'response') { + const responseMessage = msg.message as JSONRPCResponse; + const requestId = responseMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (resolver) { + testProtocol._requestResolvers.delete(requestId); + resolver(responseMessage); + } + } else if (msg.type === 'error') { + const errorMessage = msg.message as JSONRPCError; + const requestId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (resolver) { + testProtocol._requestResolvers.delete(requestId); + const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + } + + // Verify all resolvers were called correctly + expect(resolver1).toHaveBeenCalledWith(expect.objectContaining({ id: 1 })); + expect(resolver2).toHaveBeenCalledWith(expect.any(McpError)); + expect(resolver3).toHaveBeenCalledWith(expect.objectContaining({ id: 3 })); + + // Verify error has correct properties + const error = resolver2.mock.calls[0][0]; + expect(error.code).toBe(ErrorCode.InvalidRequest); + expect(error.message).toContain('Request failed'); + + // Verify all resolvers were removed + expect(testProtocol._requestResolvers.size).toBe(0); + }); + + it('should maintain FIFO order when processing responses and errors', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const testProtocol = protocol as unknown as TestProtocol; + + const callOrder: number[] = []; + const resolver1 = vi.fn(() => callOrder.push(1)); + const resolver2 = vi.fn(() => callOrder.push(2)); + const resolver3 = vi.fn(() => callOrder.push(3)); + + testProtocol._requestResolvers.set(1, resolver1); + testProtocol._requestResolvers.set(2, resolver2); + testProtocol._requestResolvers.set(3, resolver3); + + // Enqueue in specific order + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { jsonrpc: '2.0', id: 1, result: {} }, + timestamp: 1000 + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: 2, + error: { code: -32600, message: 'Error' } + }, + timestamp: 2000 + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { jsonrpc: '2.0', id: 3, result: {} }, + timestamp: 3000 + }); + + // Process all messages + let msg; + while ((msg = await taskMessageQueue.dequeue(task.taskId))) { + if (msg.type === 'response') { + const responseMessage = msg.message as JSONRPCResponse; + const requestId = responseMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (resolver) { + testProtocol._requestResolvers.delete(requestId); + resolver(responseMessage); + } + } else if (msg.type === 'error') { + const errorMessage = msg.message as JSONRPCError; + const requestId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (resolver) { + testProtocol._requestResolvers.delete(requestId); + const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + } + + // Verify FIFO order was maintained + expect(callOrder).toEqual([1, 2, 3]); + }); + }); +}); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index d24c6f395..50a85c867 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -389,6 +389,42 @@ export abstract class Protocol | undefined; - if (queuedMessage.type === 'request' && queuedMessage.responseResolver) { - // Capture in const to satisfy TypeScript's flow analysis - const msg = queuedMessage; - responsePromise = new Promise((resolve, reject) => { - const originalResolver = msg.responseResolver!; - const wrappedResolver = (response: JSONRPCResponse | Error) => { - // First, deliver the response to the task handler - originalResolver(response); - // Then, signal that we can continue delivering messages - if (response instanceof Error) { - reject(response); - } else { - resolve(); - } - }; - // Replace the resolver so _onresponse calls our wrapped version - if (msg.originalRequestId !== undefined) { - this._requestResolvers.set(msg.originalRequestId, wrappedResolver); - } - }); - } - // Send the message on the response stream by passing the relatedRequestId // This tells the transport to write the message to the tasks/result response stream await this._transport?.send(messageToSend, { relatedRequestId: extra.requestId }); - - // If it was a request, wait for the response before delivering the next message - if (responsePromise) { - await responsePromise; - } } } @@ -1128,9 +1134,7 @@ export abstract class Protocol { // Notify any waiting tasks/result calls @@ -1414,10 +1418,17 @@ export abstract class Protocol void; - /** For requests: the original request ID for response routing */ - originalRequestId?: RequestId; +} + +export interface QueuedRequest extends BaseQueuedMessage { + type: 'request'; + /** The actual JSONRPC request */ + message: JSONRPCRequest; +} + +export interface QueuedNotification extends BaseQueuedMessage { + type: 'notification'; + /** The actual JSONRPC notification */ + message: JSONRPCNotification; +} + +export interface QueuedResponse extends BaseQueuedMessage { + type: 'response'; + /** The actual JSONRPC response */ + message: JSONRPCResponse; +} + +export interface QueuedError extends BaseQueuedMessage { + type: 'error'; + /** The actual JSONRPC error */ + message: JSONRPCError; } /** * Interface for managing per-task FIFO message queues. * * Similar to TaskStore, this allows pluggable queue implementations - * (in-memory, Redis, other distributed queues, etc.) for server-initiated - * messages that will be delivered through the tasks/result response stream. + * (in-memory, Redis, other distributed queues, etc.). * * Each method accepts taskId and optional sessionId parameters to enable * a single queue instance to manage messages for multiple tasks, with * isolation based on task ID and session ID. * * All methods are async to support external storage implementations. - * - * Performance Notes: - * - enqueue() atomically enforces maxSize to prevent race conditions - * - dequeue() returns undefined when empty, eliminating need for isEmpty() checks - * - dequeueAll() is used when tasks are cancelled/failed to reject pending resolvers + * All data in QueuedMessage must be JSON-serializable. */ export interface TaskMessageQueue { /** @@ -56,7 +84,7 @@ export interface TaskMessageQueue { /** * Removes and returns all messages from the queue for a specific task. - * Used when tasks are cancelled or failed to reject any pending request resolvers. + * Used when tasks are cancelled or failed to clean up pending messages. * @param taskId The task identifier * @param sessionId Optional session ID for binding the query to a specific session * @returns Array of all messages that were in the queue From 8cf6675b8d291e2ccd3a690492d95a5b19be2c00 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 21 Nov 2025 15:13:00 -0800 Subject: [PATCH 61/84] Move annotations.taskHint to execution.taskSupport --- src/client/index.ts | 4 +- src/server/mcp.test.ts | 270 ++++++++++++----------------------------- src/server/mcp.ts | 61 +++++++--- src/types.ts | 20 ++- 4 files changed, 140 insertions(+), 215 deletions(-) diff --git a/src/client/index.ts b/src/client/index.ts index 4d2d1623d..f59449b7b 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -745,8 +745,8 @@ export class Client< } // If the tool supports task-based execution, cache that information - const taskHint = tool.annotations?.taskHint; - if (taskHint === 'always' || taskHint === 'optional') { + const taskSupport = tool.execution?.taskSupport; + if (taskSupport === 'required' || taskSupport === 'optional') { this._cachedKnownTaskTools.add(tool.name); } } diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 3a382a534..82d379ba7 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -23,7 +23,6 @@ import { completable } from './completable.js'; import { McpServer, ResourceTemplate } from './mcp.js'; import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; -import * as z4 from 'zod/v4'; function createLatch() { let latch = false; @@ -5695,7 +5694,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); describe('Tool-level task hints with automatic polling wrapper', () => { - test('should return error for tool with taskHint "always" called without task augmentation', async () => { + test('should return error for tool with taskSupport "required" called without task augmentation', async () => { const taskStore = new InMemoryTaskStore(); const mcpServer = new McpServer( @@ -5736,7 +5735,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } ); - // Register a task-based tool with taskHint "always" BEFORE connecting + // Register a task-based tool with taskSupport "required" mcpServer.registerToolTask( 'long-running-task', { @@ -5744,8 +5743,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { inputSchema: { input: z.string() }, - annotations: { - taskHint: 'always' as unknown as 'never' // override to allow violating build-time constraints + execution: { + taskSupport: 'required' } }, { @@ -5802,7 +5801,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { taskStore.cleanup(); }); - test('should automatically poll and return CallToolResult for tool with taskHint "optional" called without task augmentation', async () => { + test('should automatically poll and return CallToolResult for tool with taskSupport "optional" called without task augmentation', async () => { const taskStore = new InMemoryTaskStore(); const { releaseLatch, waitForLatch } = createLatch(); @@ -5844,7 +5843,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } ); - // Register a task-based tool with taskHint "optional" BEFORE connecting + // Register a task-based tool with taskSupport "optional" mcpServer.registerToolTask( 'optional-task', { @@ -5852,8 +5851,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { inputSchema: { value: z.number() }, - annotations: { - taskHint: 'optional' as unknown as 'never' // override to allow violating build-time constraints + execution: { + taskSupport: 'optional' } }, { @@ -5913,7 +5912,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { taskStore.cleanup(); }); - test('should return CreateTaskResult when tool with taskHint "always" is called WITH task augmentation', async () => { + test('should return CreateTaskResult when tool with taskSupport "required" is called WITH task augmentation', async () => { const taskStore = new InMemoryTaskStore(); const { releaseLatch, waitForLatch } = createLatch(); @@ -5955,7 +5954,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } ); - // Register a task-based tool with taskHint "always" BEFORE connecting + // Register a task-based tool with taskSupport "required" mcpServer.registerToolTask( 'task-tool', { @@ -5963,8 +5962,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { inputSchema: { data: z.string() }, - annotations: { - taskHint: 'always' as unknown as 'never' // override to allow violating build-time constraints + execution: { + taskSupport: 'required' } }, { @@ -6036,177 +6035,6 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { taskStore.cleanup(); }); - test('should throw error if tool with taskHint "always" is not registered with registerToolTask', async () => { - const mcpServer = new McpServer({ - name: 'test server', - version: '1.0' - }); - - const client = new Client({ - name: 'test client', - version: '1.0' - }); - - // Register a regular tool with taskHint "always" (incorrect usage) BEFORE connecting - mcpServer.registerTool( - 'bad-tool', - { - description: 'A tool with incorrect taskHint', - annotations: { - taskHint: 'always' - } - }, - async () => ({ - content: [{ type: 'text' as const, text: 'Should not work' }] - }) - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - // Call the tool - should return error result - const result = await client.callTool( - { - name: 'bad-tool', - arguments: {} - }, - CallToolResultSchema - ); - - expect(result.isError).toBe(true); - const content = result.content as TextContent[]; - expect(content[0].text).toContain("has taskHint 'always' but was not registered with registerToolTask"); - }); - - test('should throw error if tool with taskHint "optional" is not registered with registerToolTask', async () => { - const mcpServer = new McpServer({ - name: 'test server', - version: '1.0' - }); - - const client = new Client({ - name: 'test client', - version: '1.0' - }); - - // Register a regular tool with taskHint "optional" (incorrect usage) BEFORE connecting - mcpServer.registerTool( - 'bad-optional-tool', - { - description: 'A tool with incorrect taskHint', - annotations: { - taskHint: 'optional' - } - }, - async () => ({ - content: [{ type: 'text' as const, text: 'Should not work' }] - }) - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - // Call the tool - should return error result - const result = await client.callTool( - { - name: 'bad-optional-tool', - arguments: {} - }, - CallToolResultSchema - ); - - expect(result.isError).toBe(true); - const content = result.content as TextContent[]; - expect(content[0].text).toContain("has taskHint 'optional' but was not registered with registerToolTask"); - }); - - test('should work normally for tool with taskHint "never"', async () => { - const mcpServer = new McpServer({ - name: 'test server', - version: '1.0' - }); - - const client = new Client({ - name: 'test client', - version: '1.0' - }); - - // Register a regular tool with taskHint "never" BEFORE connecting - mcpServer.registerTool( - 'normal-tool', - { - description: 'A normal tool', - inputSchema: { - message: z.string() - }, - annotations: { - taskHint: 'never' - } - }, - async ({ message }) => ({ - content: [{ type: 'text' as const, text: `Echo: ${message}` }] - }) - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - // Call the tool - should work normally - const result = await client.callTool( - { - name: 'normal-tool', - arguments: { message: 'hello' } - }, - CallToolResultSchema - ); - - expect(result.content).toEqual([{ type: 'text' as const, text: 'Echo: hello' }]); - }); - - test('should work normally for tool without taskHint', async () => { - const mcpServer = new McpServer({ - name: 'test server', - version: '1.0' - }); - - const client = new Client({ - name: 'test client', - version: '1.0' - }); - - // Register a regular tool without taskHint BEFORE connecting - mcpServer.registerTool( - 'simple-tool', - { - description: 'A simple tool', - inputSchema: { - value: z4.number() - } - }, - async ({ value }) => ({ - content: [{ type: 'text' as const, text: `Value: ${value}` }] - }) - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - // Call the tool - should work normally - const result = await client.callTool( - { - name: 'simple-tool', - arguments: { value: 42 } - }, - CallToolResultSchema - ); - - expect(result.content).toEqual([{ type: 'text' as const, text: 'Value: 42' }]); - }); - test('should handle task failures during automatic polling', async () => { const taskStore = new InMemoryTaskStore(); const { releaseLatch, waitForLatch } = createLatch(); @@ -6249,13 +6077,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } ); - // Register a task-based tool that fails BEFORE connecting + // Register a task-based tool that fails mcpServer.registerToolTask( 'failing-task', { description: 'A failing task', - annotations: { - taskHint: 'optional' as unknown as 'never' // override to allow violating build-time constraints + execution: { + taskSupport: 'optional' } }, { @@ -6358,13 +6186,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } ); - // Register a task-based tool that gets cancelled BEFORE connecting + // Register a task-based tool that gets cancelled mcpServer.registerToolTask( 'cancelled-task', { description: 'A task that gets cancelled', - annotations: { - taskHint: 'optional' as unknown as 'never' // override to allow violating build-time constraints + execution: { + taskSupport: 'optional' } }, { @@ -6420,5 +6248,67 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await waitForLatch(); taskStore.cleanup(); }); + + test('should raise error when registerToolTask is called with taskSupport "forbidden"', () => { + const taskStore = new InMemoryTaskStore(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + // Attempt to register a task-based tool with taskSupport "forbidden" (cast to bypass type checking) + expect(() => { + mcpServer.registerToolTask( + 'invalid-task', + { + description: 'A task with forbidden support', + inputSchema: { + input: z.string() + }, + execution: { + taskSupport: 'forbidden' as unknown as 'required' + } + }, + { + createTask: async ({ input }, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { + method: 'tools/call', + params: { name: 'invalid-task', arguments: { input } } + }); + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async (_input, extra) => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + }).toThrow(); + + taskStore.cleanup(); + }); }); }); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 47a6f8cd9..11707dc14 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -52,7 +52,8 @@ import { CompleteRequestResourceTemplate, assertCompleteRequestPrompt, assertCompleteRequestResourceTemplate, - CallToolRequest + CallToolRequest, + ToolExecution } from '../types.js'; import { isCompletable, getCompleter } from './completable.js'; import { UriTemplate, Variables } from '../shared/uriTemplate.js'; @@ -164,27 +165,27 @@ export class McpServer { } const isTaskRequest = !!request.params.task; - const taskHint = tool.annotations?.taskHint; + const taskSupport = tool.execution?.taskSupport; const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); // Validate task hint configuration - if ((taskHint === 'always' || taskHint === 'optional') && !isTaskHandler) { + if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { throw new McpError( ErrorCode.InternalError, - `Tool ${request.params.name} has taskHint '${taskHint}' but was not registered with registerToolTask` + `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` ); } - // Handle taskHint 'always' without task augmentation - if (taskHint === 'always' && !isTaskRequest) { + // Handle taskSupport 'required' without task augmentation + if (taskSupport === 'required' && !isTaskRequest) { throw new McpError( ErrorCode.MethodNotFound, - `Tool ${request.params.name} requires task augmentation (taskHint: 'always')` + `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` ); } - // Handle taskHint 'optional' without task augmentation - automatic polling - if (taskHint === 'optional' && !isTaskRequest && isTaskHandler) { + // Handle taskSupport 'optional' without task augmentation - automatic polling + if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { return await this.handleAutomaticTaskPolling(tool, request, extra); } @@ -337,7 +338,7 @@ export class McpServer { } /** - * Handles automatic task polling for tools with taskHint 'optional'. + * Handles automatic task polling for tools with taskSupport 'optional'. */ private async handleAutomaticTaskPolling( tool: RegisteredTool, @@ -838,6 +839,7 @@ export class McpServer { inputSchema: ZodRawShapeCompat | AnySchema | undefined, outputSchema: ZodRawShapeCompat | AnySchema | undefined, annotations: ToolAnnotations | undefined, + execution: ToolExecution | undefined, _meta: Record | undefined, handler: AnyToolHandler ): RegisteredTool { @@ -850,6 +852,7 @@ export class McpServer { inputSchema: getZodSchemaObject(inputSchema), outputSchema: getZodSchemaObject(outputSchema), annotations, + execution, _meta, handler: handler, enabled: true, @@ -992,7 +995,17 @@ export class McpServer { } const callback = rest[0] as ToolCallback; - return this._createRegisteredTool(name, undefined, description, inputSchema, outputSchema, annotations, undefined, callback); + return this._createRegisteredTool( + name, + undefined, + description, + inputSchema, + outputSchema, + annotations, + { taskSupport: 'forbidden' }, + undefined, + callback + ); } /** @@ -1023,6 +1036,7 @@ export class McpServer { inputSchema, outputSchema, annotations, + { taskSupport: 'forbidden' }, _meta, cb as ToolCallback ); @@ -1037,7 +1051,8 @@ export class McpServer { title?: string; description?: string; outputSchema?: OutputArgs; - annotations?: NoTaskToolAnnotations; + annotations?: ToolAnnotations; + execution?: TaskToolExecution; _meta?: Record; }, handler: ToolTaskHandler @@ -1053,7 +1068,8 @@ export class McpServer { description?: string; inputSchema: InputArgs; outputSchema?: OutputArgs; - annotations?: NoTaskToolAnnotations; + annotations?: ToolAnnotations; + execution?: TaskToolExecution; _meta?: Record; }, handler: ToolTaskHandler @@ -1069,18 +1085,26 @@ export class McpServer { description?: string; inputSchema?: InputArgs; outputSchema?: OutputArgs; - annotations?: NoTaskToolAnnotations; + annotations?: ToolAnnotations; + execution?: TaskToolExecution; _meta?: Record; }, handler: ToolTaskHandler ): RegisteredTool { + // Validate that taskSupport is not 'forbidden' for task-based tools + const execution: ToolExecution = { taskSupport: 'required', ...config.execution }; + if (execution.taskSupport === 'forbidden') { + throw new Error(`Cannot register task-based tool '${name}' with taskSupport 'forbidden'. Use registerTool() instead.`); + } + return this._createRegisteredTool( name, config.title, config.description, config.inputSchema, config.outputSchema, - { taskHint: 'always', ...config.annotations }, + config.annotations, + execution, config._meta, handler ); @@ -1338,9 +1362,9 @@ export type AnyToolCallback = ToolCallback | ToolTaskHandler; -export interface NoTaskToolAnnotations extends ToolAnnotations { - taskHint?: 'never'; -} +export type TaskToolExecution = Omit & { + taskSupport: TaskSupport extends 'forbidden' | undefined ? never : TaskSupport; +}; export type RegisteredTool = { title?: string; @@ -1348,6 +1372,7 @@ export type RegisteredTool = { inputSchema?: AnySchema; outputSchema?: AnySchema; annotations?: ToolAnnotations; + execution?: ToolExecution; _meta?: Record; handler: AnyToolHandler; enabled: boolean; diff --git a/src/types.ts b/src/types.ts index 47b077319..843014b46 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1224,17 +1224,22 @@ export const ToolAnnotationsSchema = z.object({ * * Default: true */ - openWorldHint: z.boolean().optional(), + openWorldHint: z.boolean().optional() +}); +/** + * Execution-related properties for a tool. + */ +export const ToolExecutionSchema = z.object({ /** * Indicates the tool's preference for task-augmented execution. - * - "always": Clients SHOULD invoke the tool as a task + * - "required": Clients MUST invoke the tool as a task * - "optional": Clients MAY invoke the tool as a task or normal request - * - "never": Clients SHALL NOT attempt to invoke the tool as a task + * - "forbidden": Clients MUST NOT attempt to invoke the tool as a task * - * If not present, defaults to "never". + * If not present, defaults to "forbidden". */ - taskHint: z.enum(['always', 'optional', 'never']).optional() + taskSupport: z.enum(['required', 'optional', 'forbidden']).optional() }); /** @@ -1275,6 +1280,10 @@ export const ToolSchema = z.object({ * Optional additional tool information. */ annotations: z.optional(ToolAnnotationsSchema), + /** + * Execution-related properties for this tool. + */ + execution: z.optional(ToolExecutionSchema), /** * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) @@ -2235,6 +2244,7 @@ export type PromptListChangedNotification = Infer; +export type ToolExecution = Infer; export type Tool = Infer; export type ListToolsRequest = Infer; export type ListToolsResult = Infer; From 8acb942bd31cb53afeb7c417ba441dc432f4d9fc Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 21 Nov 2025 16:00:25 -0800 Subject: [PATCH 62/84] Remove now-unneeded _taskResultWaiters map --- src/shared/protocol.test.ts | 101 ++---------------------------------- src/shared/protocol.ts | 62 ++++------------------ 2 files changed, 13 insertions(+), 150 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 042baea7e..bb3df7987 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -25,7 +25,6 @@ import { InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.j // Type helper for accessing private Protocol properties in tests interface TestProtocol { _taskMessageQueue?: TaskMessageQueue; - _taskResultWaiters: Map void>>; _requestResolvers: Map void>; _responseHandlers: Map void>; _taskProgressTokens: Map; @@ -3074,46 +3073,8 @@ describe('Message interception for task-related notifications', () => { // notifications without relatedTask metadata are sent via transport, not queued }); - it('should notify task result waiters after queuing', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - await server.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Set up a waiter - let waiterCalled = false; - const waiters = (server as unknown as TestProtocol)._taskResultWaiters; - waiters.set(task.taskId, [ - () => { - waiterCalled = true; - } - ]); - - // Send a notification with related task metadata - await server.notification( - { - method: 'notifications/message', - params: { level: 'info', data: 'test message' } - }, - { - relatedTask: { taskId: task.taskId } - } - ); - - // Verify the waiter was called - expect(waiterCalled).toBe(true); - expect(waiters.has(task.taskId)).toBe(false); // Waiters should be cleared - }); + // Test removed: _taskResultWaiters was removed in favor of polling-based task updates + // The functionality is still tested through integration tests that verify message queuing works it('should handle queue overflow by failing the task', async () => { const taskStore = createMockTaskStore(); @@ -3324,62 +3285,8 @@ describe('Message interception for task-related requests', () => { await requestPromise; }); - it('should notify task result waiters after queuing request', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - await server.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Set up a waiter - let waiterCalled = false; - const waiters = (server as unknown as TestProtocol)._taskResultWaiters; - waiters.set(task.taskId, [ - () => { - waiterCalled = true; - } - ]); - - // Send a request with related task metadata - const requestPromise = server.request( - { - method: 'ping', - params: {} - }, - z.object({}), - { - relatedTask: { taskId: task.taskId } - } - ); - - // Wait for the request to be queued and waiter to be called - await new Promise(resolve => setTimeout(resolve, 10)); - - // Verify the waiter was called - expect(waiterCalled).toBe(true); - expect(waiters.has(task.taskId)).toBe(false); // Waiters should be cleared - - // Clean up - const queue = (server as unknown as TestProtocol)._taskMessageQueue; - const queuedMessage = await queue!.dequeue(task.taskId); - const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; - transport.onmessage?.({ - jsonrpc: '2.0', - id: requestId, - result: {} - }); - - await requestPromise; - }); + // Test removed: _taskResultWaiters was removed in favor of polling-based task updates + // The functionality is still tested through integration tests that verify message queuing works it('should store request resolver for response routing', async () => { const taskStore = createMockTaskStore(); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 50a85c867..2735d7972 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -320,8 +320,6 @@ export abstract class Protocol void>> = new Map(); private _requestResolvers: Map void> = new Map(); /** @@ -1135,15 +1133,10 @@ export abstract class Protocol { - // Notify any waiting tasks/result calls - this._notifyTaskResultWaiters(relatedTaskId); - }) - .catch(error => { - this._cleanupTimeout(messageId); - reject(error); - }); + }).catch(error => { + this._cleanupTimeout(messageId); + reject(error); + }); // Don't send through transport - queued messages are delivered via tasks/result only // This prevents duplicate delivery for bidirectional transports @@ -1225,9 +1218,6 @@ export abstract class Protocol { + clearTimeout(timeoutId); reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); }, { once: true } ); - - // Polling as a fallback mechanism for edge cases and race conditions - // Most updates will be handled by event-driven notifications via _notifyTaskResultWaiters() - // We trigger notification on every poll - the handler will check if work is available - const pollInterval = setInterval(async () => { - try { - this._notifyTaskResultWaiters(taskId); - } catch { - // Ignore errors during polling - } - }, interval); - - // Clean up the interval when the promise resolves or rejects - const cleanup = () => clearInterval(pollInterval); - signal.addEventListener('abort', cleanup, { once: true }); }); } From e7143f1173e3024e458a832c3f92dffad4906ad2 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 21 Nov 2025 16:12:27 -0800 Subject: [PATCH 63/84] Don't fail the task or dump the queue on a failed enqueue operation --- src/shared/protocol.test.ts | 24 +++++++++++++----------- src/shared/protocol.ts | 24 ++++++++++-------------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index bb3df7987..9df10e484 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -3076,7 +3076,7 @@ describe('Message interception for task-related notifications', () => { // Test removed: _taskResultWaiters was removed in favor of polling-based task updates // The functionality is still tested through integration tests that verify message queuing works - it('should handle queue overflow by failing the task', async () => { + it('should propagate queue overflow errors without failing the task', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); const server = new (class extends Protocol { @@ -3105,7 +3105,7 @@ describe('Message interception for task-related notifications', () => { ); } - // Try to add one more message - should throw and fail the task + // Try to add one more message - should throw an error await expect( server.notification( { @@ -3116,10 +3116,11 @@ describe('Message interception for task-related notifications', () => { relatedTask: { taskId: task.taskId } } ) - ).rejects.toThrow(McpError); + ).rejects.toThrow('overflow'); - // Verify the task was failed with overflow error - expect(taskStore.updateTaskStatus).toHaveBeenCalledWith(task.taskId, 'failed', expect.stringContaining('overflow'), undefined); + // Verify the task was NOT automatically failed by the Protocol + // (implementations can choose to fail tasks on overflow if they want) + expect(taskStore.updateTaskStatus).not.toHaveBeenCalledWith(task.taskId, 'failed', expect.anything(), expect.anything()); }); it('should extract task ID correctly from metadata', async () => { @@ -3482,7 +3483,7 @@ describe('Message interception for task-related requests', () => { expect(errors.some(e => e.message.includes('Response handler missing for request'))).toBe(true); }); - it('should handle queue overflow for requests', async () => { + it('should propagate queue overflow errors for requests without failing the task', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); const server = new (class extends Protocol { @@ -3513,12 +3514,12 @@ describe('Message interception for task-related requests', () => { } ) .catch(() => { - // Expected to reject when queue is cleared + // Requests will remain pending until task completes or fails }); promises.push(promise); } - // Try to add one more request - should throw and fail the task + // Try to add one more request - should throw an error await expect( server.request( { @@ -3530,10 +3531,11 @@ describe('Message interception for task-related requests', () => { relatedTask: { taskId: task.taskId } } ) - ).rejects.toThrow(McpError); + ).rejects.toThrow('overflow'); - // Verify the task was failed with overflow error - expect(taskStore.updateTaskStatus).toHaveBeenCalledWith(task.taskId, 'failed', expect.stringContaining('overflow'), undefined); + // Verify the task was NOT automatically failed by the Protocol + // (implementations can choose to fail tasks on overflow if they want) + expect(taskStore.updateTaskStatus).not.toHaveBeenCalledWith(task.taskId, 'failed', expect.anything(), expect.anything()); }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 2735d7972..4604737e9 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -91,7 +91,9 @@ export type ProtocolOptions = { /** * Maximum number of messages that can be queued per task for side-channel delivery. * If undefined, the queue size is unbounded. - * When the limit is exceeded, the task will be transitioned to failed status. + * When the limit is exceeded, the TaskMessageQueue implementation's enqueue() method + * will throw an error. It's the implementation's responsibility to handle overflow + * appropriately (e.g., by failing the task, dropping messages, etc.). */ maxTaskQueueSize?: number; }; @@ -1376,26 +1378,20 @@ export abstract class Protocol { // Task message queues are only used when taskStore is configured if (!this._taskStore || !this._taskMessageQueue) { - throw new McpError(ErrorCode.InternalError, 'Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); + throw new Error('Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); } const maxQueueSize = this._options?.maxTaskQueueSize; - - try { - await this._taskMessageQueue.enqueue(taskId, message, sessionId, maxQueueSize); - } catch (error) { - // Enqueue failed (e.g., queue overflow, storage error) - fail the task and clear the queue - this._onerror(error as Error); - const errorMessage = error instanceof Error ? error.message : 'Task message enqueue failed'; - this._taskStore.updateTaskStatus(taskId, 'failed', errorMessage, sessionId).catch(err => this._onerror(err)); - await this._clearTaskQueue(taskId, sessionId); - throw new McpError(ErrorCode.InternalError, `Failed to enqueue task message: ${errorMessage}`); - } + await this._taskMessageQueue.enqueue(taskId, message, sessionId, maxQueueSize); } /** From 5fcad03f7434e2258829abd55da5956fba6b92dd Mon Sep 17 00:00:00 2001 From: Luca Chang <131398524+LucaButBoring@users.noreply.github.com> Date: Mon, 24 Nov 2025 11:29:00 -0800 Subject: [PATCH 64/84] Update example code for tasks to match impl Co-authored-by: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> --- src/examples/server/simpleStreamableHttp.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 796e9d13f..5dc926533 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -478,13 +478,13 @@ const getServer = () => { pollInterval: 100 }, 0, - { method: 'tools/call', params: { name: 'createTask', arguments: { duration } } } + { method: 'tools/call', params: { name: 'delay', arguments: { duration } } } ); // Simulate out-of-band work (async () => { await new Promise(resolve => setTimeout(resolve, duration)); - await taskStore.storeTaskResult(taskId, 'completed', { + await taskStore.storeTaskResult(task.taskId, 'completed', { content: [ { type: 'text', From cea8e6b096989735fdb7acb3015a19172992c2cd Mon Sep 17 00:00:00 2001 From: Luca Chang <131398524+LucaButBoring@users.noreply.github.com> Date: Mon, 24 Nov 2025 11:36:13 -0800 Subject: [PATCH 65/84] Update src/examples/server/simpleStreamableHttp.ts Co-authored-by: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> --- src/examples/server/simpleStreamableHttp.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 5dc926533..00afe1e7a 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -468,7 +468,6 @@ const getServer = () => { { async createTask({ duration }, { taskStore, taskRequestedTtl }) { // Generate a simple task ID (in production, use a more secure method) - const taskId = `task-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`; // Create the task const task = await taskStore.createTask( From 6839899f2a47d052e7ad8d42fc4e66d1b40a70e5 Mon Sep 17 00:00:00 2001 From: Luca Chang <131398524+LucaButBoring@users.noreply.github.com> Date: Mon, 24 Nov 2025 11:36:22 -0800 Subject: [PATCH 66/84] Update src/examples/server/simpleStreamableHttp.ts Co-authored-by: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> --- src/examples/server/simpleStreamableHttp.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 00afe1e7a..7f312a723 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -467,7 +467,6 @@ const getServer = () => { }, { async createTask({ duration }, { taskStore, taskRequestedTtl }) { - // Generate a simple task ID (in production, use a more secure method) // Create the task const task = await taskStore.createTask( From 0e84bd48040cba5fd321e40cc2ecaf2bb0941b47 Mon Sep 17 00:00:00 2001 From: Luca Chang <131398524+LucaButBoring@users.noreply.github.com> Date: Mon, 24 Nov 2025 11:36:45 -0800 Subject: [PATCH 67/84] Update src/examples/server/simpleStreamableHttp.ts Co-authored-by: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> --- src/examples/server/simpleStreamableHttp.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 7f312a723..c82fc94c2 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -471,7 +471,6 @@ const getServer = () => { // Create the task const task = await taskStore.createTask( { - taskId, ttl: taskRequestedTtl, pollInterval: 100 }, From a24fece487c1de41adff918052ccea326270071f Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 24 Nov 2025 11:53:42 -0800 Subject: [PATCH 68/84] Fix inconsistencies in task example, tweak interfaces --- src/examples/server/simpleStreamableHttp.ts | 8 ++-- src/examples/shared/inMemoryTaskStore.test.ts | 2 +- src/examples/shared/inMemoryTaskStore.ts | 8 ++-- src/shared/protocol.ts | 11 +++--- src/shared/task.ts | 39 ++++++++++++------- src/types.ts | 1 - 6 files changed, 39 insertions(+), 30 deletions(-) diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index c82fc94c2..026b0ac92 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -466,15 +466,13 @@ const getServer = () => { } }, { - async createTask({ duration }, { taskStore, taskRequestedTtl }) { - + async createTask({ duration }, { taskStore, taskRequestedTtl, requestId }) { // Create the task const task = await taskStore.createTask( { - ttl: taskRequestedTtl, - pollInterval: 100 + ttl: taskRequestedTtl }, - 0, + requestId, { method: 'tools/call', params: { name: 'delay', arguments: { duration } } } ); diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index cba00d987..658e4deb1 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -32,7 +32,7 @@ describe('InMemoryTaskStore', () => { expect(task.taskId.length).toBeGreaterThan(0); expect(task.status).toBe('working'); expect(task.ttl).toBe(60000); - expect(task.pollInterval).toBe(500); + expect(task.pollInterval).toBeDefined(); expect(task.createdAt).toBeDefined(); expect(new Date(task.createdAt).getTime()).toBeGreaterThan(0); }); diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index ad0ab4de6..8e077cc19 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -1,5 +1,5 @@ -import { Task, TaskCreationParams, Request, RequestId, Result } from '../../types.js'; -import { TaskStore, isTerminal, TaskMessageQueue, QueuedMessage } from '../../shared/task.js'; +import { Task, Request, RequestId, Result } from '../../types.js'; +import { TaskStore, isTerminal, TaskMessageQueue, QueuedMessage, CreateTaskOptions } from '../../shared/task.js'; import { randomBytes } from 'crypto'; interface StoredTask { @@ -30,7 +30,7 @@ export class InMemoryTaskStore implements TaskStore { return randomBytes(16).toString('hex'); } - async createTask(taskParams: TaskCreationParams, requestId: RequestId, request: Request, _sessionId?: string): Promise { + async createTask(taskParams: CreateTaskOptions, requestId: RequestId, request: Request, _sessionId?: string): Promise { // Generate a unique task ID const taskId = this.generateTaskId(); @@ -47,7 +47,7 @@ export class InMemoryTaskStore implements TaskStore { status: 'working', ttl: actualTtl, createdAt: new Date().toISOString(), - pollInterval: taskParams.pollInterval ?? 500 + pollInterval: taskParams.pollInterval ?? 1000 }; this.tasks.set(taskId, { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 4604737e9..1dc8b0c3a 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -45,7 +45,7 @@ import { } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; -import { isTerminal, TaskStore, TaskMessageQueue, QueuedMessage } from './task.js'; +import { isTerminal, TaskStore, TaskMessageQueue, QueuedMessage, CreateTaskOptions } from './task.js'; import { getMethodLiteral, parseWithCompat } from '../server/zod-json-schema-compat.js'; import { ResponseMessage } from './responseMessage.js'; @@ -180,18 +180,19 @@ export interface RequestTaskStore { * Creates a new task with the given creation parameters. * The implementation generates a unique taskId and createdAt timestamp. * - * @param taskParams - The task creation parameters from the request (ttl, pollInterval) + * @param taskParams - The task creation parameters from the request * @param requestId - The JSON-RPC request ID * @param request - The original request that triggered task creation - * @returns The task state including generated taskId, createdAt timestamp, status, ttl, pollInterval, and optional statusMessage + * @returns The created task object */ - createTask(taskParams: TaskCreationParams, requestId: RequestId, request: Request): Promise; + createTask(taskParams: CreateTaskOptions, requestId: RequestId, request: Request): Promise; /** * Gets the current status of a task. * * @param taskId - The task identifier - * @returns The task state including status, ttl, pollInterval, and optional statusMessage + * @returns The task object + * @throws If the task does not exist */ getTask(taskId: string): Promise; diff --git a/src/shared/task.ts b/src/shared/task.ts index 45f0e1769..ae4517f6f 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -1,14 +1,4 @@ -import { - Task, - TaskCreationParams, - Request, - RequestId, - Result, - JSONRPCRequest, - JSONRPCNotification, - JSONRPCResponse, - JSONRPCError -} from '../types.js'; +import { Task, Request, RequestId, Result, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, JSONRPCError } from '../types.js'; /** * Represents a message queued for side-channel delivery via tasks/result. @@ -92,6 +82,27 @@ export interface TaskMessageQueue { dequeueAll(taskId: string, sessionId?: string): Promise; } +/** + * Task creation options. + */ +export interface CreateTaskOptions { + /** + * Time in milliseconds to keep task results available after completion. + * If null, the task has unlimited lifetime until manually cleaned up. + */ + ttl?: number | null; + + /** + * Time in milliseconds to wait between task status requests. + */ + pollInterval?: number; + + /** + * Additional context to pass to the task store. + */ + context?: Record; +} + /** * Interface for storing and retrieving task state and results. * @@ -114,16 +125,16 @@ export interface TaskStore { * @param requestId - The JSON-RPC request ID * @param request - The original request that triggered task creation * @param sessionId - Optional session ID for binding the task to a specific session - * @returns The task state including generated taskId, createdAt timestamp, status, ttl, pollInterval, and optional statusMessage + * @returns The created task object */ - createTask(taskParams: TaskCreationParams, requestId: RequestId, request: Request, sessionId?: string): Promise; + createTask(taskParams: CreateTaskOptions, requestId: RequestId, request: Request, sessionId?: string): Promise; /** * Gets the current status of a task. * * @param taskId - The task identifier * @param sessionId - Optional session ID for binding the query to a specific session - * @returns The task state including status, ttl, pollInterval, and optional statusMessage + * @returns The task object, or null if it does not exist */ getTask(taskId: string, sessionId?: string): Promise; diff --git a/src/types.ts b/src/types.ts index 843014b46..0c55e9dc1 100644 --- a/src/types.ts +++ b/src/types.ts @@ -32,7 +32,6 @@ export const CursorSchema = z.string(); /** * Task creation parameters, used to ask that the server create a task to represent a request. - * The taskId is generated by the receiver, not provided by the requestor. */ export const TaskCreationParamsSchema = z.looseObject({ /** From 05f697fb65f99361bb23775c6c7355c5f19a2c54 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 24 Nov 2025 12:10:46 -0800 Subject: [PATCH 69/84] Simplify request interface further Removes some more now-redundant parameters from RequestTaskStore.createTask and removes a redundant throw from the example. --- src/examples/server/simpleStreamableHttp.ts | 19 +++++-------------- src/shared/protocol.ts | 4 +--- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 026b0ac92..4a9c00e95 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -466,15 +466,11 @@ const getServer = () => { } }, { - async createTask({ duration }, { taskStore, taskRequestedTtl, requestId }) { + async createTask({ duration }, { taskStore, taskRequestedTtl }) { // Create the task - const task = await taskStore.createTask( - { - ttl: taskRequestedTtl - }, - requestId, - { method: 'tools/call', params: { name: 'delay', arguments: { duration } } } - ); + const task = await taskStore.createTask({ + ttl: taskRequestedTtl + }); // Simulate out-of-band work (async () => { @@ -495,12 +491,7 @@ const getServer = () => { }; }, async getTask(_args, { taskId, taskStore }) { - const task = await taskStore.getTask(taskId); - if (!task) { - throw new Error(`Task ${taskId} not found`); - } - - return task; + return await taskStore.getTask(taskId); }, async getTaskResult(_args, { taskId, taskStore }) { const result = await taskStore.getTaskResult(taskId); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 1dc8b0c3a..757a68ff5 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -181,11 +181,9 @@ export interface RequestTaskStore { * The implementation generates a unique taskId and createdAt timestamp. * * @param taskParams - The task creation parameters from the request - * @param requestId - The JSON-RPC request ID - * @param request - The original request that triggered task creation * @returns The created task object */ - createTask(taskParams: CreateTaskOptions, requestId: RequestId, request: Request): Promise; + createTask(taskParams: CreateTaskOptions): Promise; /** * Gets the current status of a task. From 76e27ea7dc61ea97797c870e9d77c6b6842df98f Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 24 Nov 2025 12:26:02 -0800 Subject: [PATCH 70/84] Remove usages of removed params in tests --- src/client/index.test.ts | 100 ++++++-------------- src/integration-tests/taskLifecycle.test.ts | 72 +++++--------- src/server/index.test.ts | 90 ++++++------------ src/server/mcp.test.ts | 34 ++----- src/shared/protocol.test.ts | 2 +- 5 files changed, 90 insertions(+), 208 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 7ab683b72..ba8c88af6 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -1767,13 +1767,9 @@ describe('Task-based execution', () => { }, { async createTask(_args, extra) { - const task = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } - ); + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); const result = { content: [{ type: 'text', text: 'Tool executed successfully!' }] @@ -1847,13 +1843,9 @@ describe('Task-based execution', () => { }, { async createTask(_args, extra) { - const task = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } - ); + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); const result = { content: [{ type: 'text', text: 'Success!' }] @@ -1928,13 +1920,9 @@ describe('Task-based execution', () => { }, { async createTask(_args, extra) { - const task = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } - ); + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); const result = { content: [{ type: 'text', text: 'Result data!' }] @@ -2007,13 +1995,9 @@ describe('Task-based execution', () => { }, { async createTask(_args, extra) { - const task = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } - ); + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); const result = { content: [{ type: 'text', text: 'Success!' }] @@ -2112,13 +2096,9 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } const result = { @@ -2209,13 +2189,9 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } const result = { @@ -2305,13 +2281,9 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } const result = { @@ -2400,13 +2372,9 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } const result = { @@ -2514,13 +2482,9 @@ describe('Task-based execution', () => { }, { async createTask({ id }, extra) { - const task = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - { method: 'tools/call', params: { name: 'test-tool', arguments: { id } } } - ); + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); const result = { content: [{ type: 'text', text: `Result for ${id || 'unknown'}` }] @@ -2786,13 +2750,9 @@ test('should respect server task capabilities', async () => { }, { async createTask(_args, extra) { - const task = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } - ); + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); const result = { content: [{ type: 'text', text: 'Success!' }] diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts index 27b64b235..3aba46b07 100644 --- a/src/integration-tests/taskLifecycle.test.ts +++ b/src/integration-tests/taskLifecycle.test.ts @@ -54,14 +54,10 @@ describe('Task Lifecycle Integration Tests', () => { }, { async createTask({ duration, shouldFail }, extra) { - const task = await extra.taskStore.createTask( - { - ttl: 60000, - pollInterval: 100 - }, - 0, - { method: 'tools/call', params: { name: 'long-task', arguments: { duration, shouldFail } } } - ); + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); // Simulate async work (async () => { @@ -111,14 +107,10 @@ describe('Task Lifecycle Integration Tests', () => { }, { async createTask({ userName }, extra) { - const task = await extra.taskStore.createTask( - { - ttl: 60000, - pollInterval: 100 - }, - 0, - { method: 'tools/call', params: { name: 'input-task', arguments: { userName } } } - ); + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); // Perform async work that requires elicitation (async () => { @@ -415,14 +407,10 @@ describe('Task Lifecycle Integration Tests', () => { }, { async createTask({ requestCount }, extra) { - const task = await extra.taskStore.createTask( - { - ttl: 60000, - pollInterval: 100 - }, - 0, - { method: 'tools/call', params: { name: 'multi-request-task', arguments: { requestCount } } } - ); + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); // Perform async work that sends multiple requests (async () => { @@ -912,14 +900,10 @@ describe('Task Lifecycle Integration Tests', () => { }, { async createTask({ messageCount }, extra) { - const task = await extra.taskStore.createTask( - { - ttl: 60000, - pollInterval: 100 - }, - 0, - { method: 'tools/call', params: { name: 'cancellable-task', arguments: { messageCount } } } - ); + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); // Perform async work that queues messages (async () => { @@ -1114,14 +1098,10 @@ describe('Task Lifecycle Integration Tests', () => { }, { async createTask({ messageCount, delayBetweenMessages }, extra) { - const task = await extra.taskStore.createTask( - { - ttl: 60000, - pollInterval: 100 - }, - 0, - { method: 'tools/call', params: { name: 'streaming-task', arguments: { messageCount, delayBetweenMessages } } } - ); + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); // Perform async work that sends messages over time (async () => { @@ -1334,14 +1314,10 @@ describe('Task Lifecycle Integration Tests', () => { }, { async createTask({ messageCount }, extra) { - const task = await extra.taskStore.createTask( - { - ttl: 60000, - pollInterval: 100 - }, - 0, - { method: 'tools/call', params: { name: 'quick-complete-task', arguments: { messageCount } } } - ); + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); // Perform async work that queues messages and completes quickly (async () => { diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 2b6c1ae09..fe46c5706 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -2024,13 +2024,9 @@ describe('Task-based execution', () => { }, { async createTask(_args, extra) { - const task = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - { method: 'tools/call', params: { name: 'test-tool', arguments: {} } } - ); + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); // Simulate some async work (async () => { @@ -2220,13 +2216,9 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } @@ -2250,13 +2242,9 @@ describe('Task-based execution', () => { }, { async createTask(_args, extra) { - const task = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - { method: 'tools/call', params: { name: 'collect-info', arguments: {} } } - ); + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); // Perform async work that makes a nested request (async () => { @@ -2378,13 +2366,9 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } const result = { @@ -2463,13 +2447,9 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } const result = { @@ -2546,13 +2526,9 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } const result = { @@ -2631,13 +2607,9 @@ describe('Task-based execution', () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } const result = { @@ -2747,13 +2719,9 @@ describe('Task-based execution', () => { }, { async createTask({ delay, taskNum }, extra) { - const task = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - { method: 'tools/call', params: { name: 'async-tool', arguments: { delay, taskNum } } } - ); + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); // Simulate async work (async () => { @@ -2984,13 +2952,9 @@ test('should respect client task capabilities', async () => { // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask( - { - ttl: extra.taskRequestedTtl - }, - extra.requestId, - request - ); + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); taskId = createdTask.taskId; } const result = { diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 82d379ba7..bb25440ac 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -5749,10 +5749,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async ({ input }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { - method: 'tools/call', - params: { name: 'long-running-task', arguments: { input } } - }); + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout const store = extra.taskStore; @@ -5857,10 +5854,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async ({ value }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { - method: 'tools/call', - params: { name: 'optional-task', arguments: { value } } - }); + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout const store = extra.taskStore; @@ -5968,10 +5962,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async ({ data }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { - method: 'tools/call', - params: { name: 'task-tool', arguments: { data } } - }); + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout const store = extra.taskStore; @@ -6088,10 +6079,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async extra => { - const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { - method: 'tools/call', - params: { name: 'failing-task', arguments: {} } - }); + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout const store = extra.taskStore; @@ -6197,10 +6185,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async extra => { - const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { - method: 'tools/call', - params: { name: 'cancelled-task', arguments: {} } - }); + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout const store = extra.taskStore; @@ -6286,11 +6271,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async ({ input }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }, extra.requestId, { - method: 'tools/call', - params: { name: 'invalid-task', arguments: { input } } - }); + createTask: async (_args, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); return { task }; }, getTask: async (_args, extra) => { @@ -6300,7 +6282,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } return task; }, - getTaskResult: async (_input, extra) => { + getTaskResult: async (_args, extra) => { const result = await extra.taskStore.getTaskResult(extra.taskId); return result as CallToolResult; } diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 9df10e484..feb48741e 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -2431,7 +2431,7 @@ describe('Progress notification support for tasks', () => { // Set up a request handler that will complete the task protocol.setRequestHandler(CallToolRequestSchema, async (request, extra) => { if (extra.taskStore) { - const task = await extra.taskStore.createTask({ ttl: 60000 }, extra.requestId, request); + const task = await extra.taskStore.createTask({ ttl: 60000 }); // Simulate async work then complete the task setTimeout(async () => { From 0d7e70d1ab92ce8ace6f15b12cd4480bc4cdf42a Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 24 Nov 2025 12:27:23 -0800 Subject: [PATCH 71/84] Update README to reflect latest task changes --- README.md | 163 ++++++++++++-------- src/examples/client/simpleStreamableHttp.ts | 2 +- 2 files changed, 97 insertions(+), 68 deletions(-) diff --git a/README.md b/README.md index e3e8dfe9a..8f3c40f98 100644 --- a/README.md +++ b/README.md @@ -1405,29 +1405,29 @@ import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprot // Implement TaskStore backed by your database (e.g., PostgreSQL, Redis, etc.) class MyTaskStore implements TaskStore { - async createTask(taskParams, requestId, request) { - // Generate unique taskId and createdAt timestamp - // Store task in your database - // Return Task object with generated taskId + async createTask(taskParams, requestId, request, sessionId?): Promise { + // Generate unique taskId and lastUpdatedAt/createdAt timestamps + // Store task in your database, using the session ID as a proxy to restrict unauthorized access + // Return final Task object } - async getTask(taskId) { + async getTask(taskId): Promise { // Retrieve task from your database } - async updateTaskStatus(taskId, status, statusMessage?) { + async updateTaskStatus(taskId, status, statusMessage?): Promise { // Update task status in your database } - async storeTaskResult(taskId, result) { + async storeTaskResult(taskId, result): Promise { // Store task result in your database } - async getTaskResult(taskId) { + async getTaskResult(taskId): Promise { // Retrieve task result from your database } - async listTasks(cursor?, sessionId?) { + async listTasks(cursor?, sessionId?): Promise<{ tasks: Task[]; nextCursor?: string }> { // List tasks with pagination support } } @@ -1441,51 +1441,74 @@ const server = new Server( }, { capabilities: { - tools: {} + tools: {}, + // Declare capabilities + tasks: { + list: {}, + cancel: {}, + requests: { + tools: { + // Declares support for tasks on tools/call + call: {} + } + } + } }, taskStore // Enable task support } ); -// Set up a long-running tool handler as usual -server.setRequestHandler(CallToolRequestSchema, async request => { - if (request.params.name === 'analyze-data') { - // Simulate long-running analysis - await new Promise(resolve => setTimeout(resolve, 30000)); +// Register a tool that supports tasks +server.registerToolTask( + 'my-echo-tool', + { + title: 'My Echo Tool', + description: 'A simple task-based echo tool.', + inputSchema: { + message: z.string().describe('Message to send') + } + }, + { + async createTask({ message }, { taskStore, taskRequestedTtl, requestId }) { + // Create the task + const task = await taskStore.createTask({ + ttl: taskRequestedTtl + }); - return { - content: [ - { - type: 'text', - text: 'Analysis complete!' - } - ] - }; - } - throw new Error('Unknown tool'); -}); + // Simulate out-of-band work + (async () => { + await new Promise(resolve => setTimeout(resolve, 5000)); + await taskStore.storeTaskResult(task.taskId, 'completed', { + content: [ + { + type: 'text', + text: message + } + ] + }); + })(); -server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'analyze-data', - description: 'Perform data analysis (long-running)', - inputSchema: { - type: 'object', - properties: { - dataset: { type: 'string' } - } - } + // Return CreateTaskResult with the created task + return { task }; + }, + async getTask(_args, { taskId, taskStore }) { + // Retrieve the task + return await taskStore.getTask(taskId); + }, + async getTaskResult(_args, { taskId, taskStore }) { + // Retrieve the result of the task + const result = await taskStore.getTaskResult(taskId); + return result as CallToolResult; } - ] -})); + } +); ``` -**Note**: See `src/examples/shared/inMemoryTaskStore.ts` in the SDK source for a reference implementation suitable for development and testing. +**Note**: See `src/examples/shared/inMemoryTaskStore.ts` in the SDK source for a reference task store implementation suitable for development and testing. #### Client-Side: Using Task-Based Execution -Clients use `beginCallTool()` to initiate task-based operations. The returned `PendingRequest` object provides automatic polling and status tracking: +Clients use `callToolStream()` to initiate task-augmented tool calls. The returned `AsyncGenerator` abstracts automatic polling and status updates: ```typescript import { Client } from '@modelcontextprotocol/sdk/client/index.js'; @@ -1498,35 +1521,40 @@ const client = new Client({ // ... connect to server ... -// Initiate a task-based tool call -const taskId = 'analysis-task-123'; -const pendingRequest = client.beginCallTool( +// Call the tool with task metadata using streaming API +const stream = client.callToolStream( { - name: 'analyze-data', - arguments: { dataset: 'user-data.csv' } + name: 'my-echo-tool', + arguments: { message: 'Hello, world!' } }, - CallToolResultSchema, - { - task: { - taskId, - keepAlive: 300000 // Keep results for 5 minutes after completion - } - } + CallToolResultSchema ); -// Option 1: Wait for completion with status callbacks -const result = await pendingRequest.result({ - onTaskCreated: () => { - console.log('Task created successfully'); - }, - onTaskStatus: task => { - console.log(`Task status: ${task.status}`); - // Status can be: 'submitted', 'working', 'input_required', 'completed', 'failed', or 'cancelled' +// Iterate the stream and handle stream events +let taskId = ''; +for await (const message of stream) { + switch (message.type) { + case 'taskCreated': + console.log('Task created successfully with ID:', message.task.taskId); + taskId = message.task.taskId; + break; + case 'taskStatus': + console.log(` ${message.task.status}${message.task.statusMessage ?? ''}`); + break; + case 'result': + console.log('Task completed! Tool result:'); + message.result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } + }); + break; + case 'error': + throw message.error; } -}); -console.log('Task completed:', result); +} -// Option 2: Fire and forget - disconnect and reconnect later +// Optional: Fire and forget - disconnect and reconnect later // (useful when you don't want to wait for long-running tasks) // Later, after disconnecting and reconnecting to the server: const taskStatus = await client.getTask({ taskId }); @@ -1538,19 +1566,20 @@ if (taskStatus.status === 'completed') { } ``` +The `callToolStream()` method also works with non-task tools, making it a drop-in replacement for `callTool()` in applications that support it. When used to invoke a tool that doesn't support tasks, the `taskCreated` and `taskStatus` events will not be emitted. + #### Task Status Lifecycle Tasks transition through the following states: -- **submitted**: Task has been created and queued - **working**: Task is actively being processed - **input_required**: Task is waiting for additional input (e.g., from elicitation) - **completed**: Task finished successfully - **failed**: Task encountered an error - **cancelled**: Task was cancelled by the client -- **unknown**: Task status could not be determined (terminal state, rarely occurs) -The `keepAlive` parameter determines how long the server retains task results after completion. This allows clients to retrieve results even after disconnecting and reconnecting. +The `ttl` parameter suggests how long the server will manage the task for. If the task duration exceeds this, the server may delete the task prematurely. The client's suggested value may be overridden by the server, and the final TTL will be provided in `Task.ttl` in +`taskCreated` and `taskStatus` events. ### Writing MCP Clients diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 74f426217..4dbd109d6 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -837,7 +837,7 @@ async function callToolTask(name: string, args: Record): Promis for await (const message of stream) { switch (message.type) { case 'taskCreated': - console.log('Task created successfully'); + console.log('Task created successfully with ID:', message.task.taskId); break; case 'taskStatus': if (lastStatus !== message.task.status) { From 4029546e8a9588b6b3082435ca79769212c033ab Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 24 Nov 2025 12:34:37 -0800 Subject: [PATCH 72/84] Add lastUpdatedAt to Task --- src/examples/shared/inMemoryTaskStore.ts | 9 ++++- src/shared/protocol.test.ts | 51 +++++++++++++++++++++--- src/shared/task.test.ts | 34 +++++++++++++--- src/types.ts | 4 ++ 4 files changed, 85 insertions(+), 13 deletions(-) diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index 8e077cc19..0e3716bdf 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -41,12 +41,14 @@ export class InMemoryTaskStore implements TaskStore { const actualTtl = taskParams.ttl ?? null; - // Create task with generated ID and timestamp + // Create task with generated ID and timestamps + const createdAt = new Date().toISOString(); const task: Task = { taskId, status: 'working', ttl: actualTtl, - createdAt: new Date().toISOString(), + createdAt, + lastUpdatedAt: createdAt, pollInterval: taskParams.pollInterval ?? 1000 }; @@ -90,6 +92,7 @@ export class InMemoryTaskStore implements TaskStore { stored.result = result; stored.task.status = status; + stored.task.lastUpdatedAt = new Date().toISOString(); // Reset cleanup timer to start from now (if ttl is set) if (stored.task.ttl) { @@ -138,6 +141,8 @@ export class InMemoryTaskStore implements TaskStore { stored.task.statusMessage = statusMessage; } + stored.task.lastUpdatedAt = new Date().toISOString(); + // If task is in a terminal state and has ttl, start cleanup timer if (isTerminal(status) && stored.task.ttl) { const existingTimer = this.cleanupTimers.get(taskId); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index feb48741e..9ec39c871 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -54,11 +54,13 @@ function createMockTaskStore(options?: { createTask: vi.fn((taskParams: TaskCreationParams, _1: RequestId, _2: Request) => { // Generate a unique task ID const taskId = `test-task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + const createdAt = new Date().toISOString(); const task = (tasks[taskId] = { taskId, status: 'working', ttl: taskParams.ttl ?? null, - createdAt: new Date().toISOString(), + createdAt, + lastUpdatedAt: createdAt, pollInterval: taskParams.pollInterval ?? 1000 }); options?.onStatus?.('working'); @@ -1335,8 +1337,22 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(3); expect(sentMessage.result.tasks).toEqual([ - { taskId: task1.taskId, status: 'completed', ttl: null, createdAt: expect.any(String), pollInterval: 500 }, - { taskId: task2.taskId, status: 'working', ttl: 60000, createdAt: expect.any(String), pollInterval: 1000 } + { + taskId: task1.taskId, + status: 'completed', + ttl: null, + createdAt: expect.any(String), + lastUpdatedAt: expect.any(String), + pollInterval: 500 + }, + { + taskId: task2.taskId, + status: 'working', + ttl: 60000, + createdAt: expect.any(String), + lastUpdatedAt: expect.any(String), + pollInterval: 1000 + } ]); expect(sentMessage.result._meta).toEqual({}); }); @@ -1384,7 +1400,14 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(2); expect(sentMessage.result.tasks).toEqual([ - { taskId: task3.taskId, status: 'working', ttl: null, createdAt: expect.any(String), pollInterval: 500 } + { + taskId: task3.taskId, + status: 'working', + ttl: null, + createdAt: expect.any(String), + lastUpdatedAt: expect.any(String), + pollInterval: 500 + } ]); expect(sentMessage.result.nextCursor).toBeUndefined(); expect(sentMessage.result._meta).toEqual({}); @@ -1472,7 +1495,16 @@ describe('Task-based execution', () => { jsonrpc: '2.0', id: sendSpy.mock.calls[0][0].id, result: { - tasks: [{ taskId: 'task-1', status: 'completed', ttl: null, createdAt: '2024-01-01T00:00:00Z', pollInterval: 500 }], + tasks: [ + { + taskId: 'task-1', + status: 'completed', + ttl: null, + createdAt: '2024-01-01T00:00:00Z', + lastUpdatedAt: '2024-01-01T00:00:00Z', + pollInterval: 500 + } + ], nextCursor: undefined, _meta: {} } @@ -1504,7 +1536,14 @@ describe('Task-based execution', () => { id: sendSpy.mock.calls[0][0].id, result: { tasks: [ - { taskId: 'task-11', status: 'working', ttl: 30000, createdAt: '2024-01-01T00:00:00Z', pollInterval: 1000 } + { + taskId: 'task-11', + status: 'working', + ttl: 30000, + createdAt: '2024-01-01T00:00:00Z', + lastUpdatedAt: '2024-01-01T00:00:00Z', + pollInterval: 1000 + } ], nextCursor: 'task-11', _meta: {} diff --git a/src/shared/task.test.ts b/src/shared/task.test.ts index 4d3843740..4d21e3dc3 100644 --- a/src/shared/task.test.ts +++ b/src/shared/task.test.ts @@ -28,11 +28,13 @@ describe('Task utility functions', () => { describe('Task Schema Validation', () => { it('should validate task with ttl field', () => { + const createdAt = new Date().toISOString(); const task: Task = { taskId: 'test-123', status: 'working', ttl: 60000, - createdAt: new Date().toISOString(), + createdAt, + lastUpdatedAt: createdAt, pollInterval: 1000 }; @@ -42,22 +44,26 @@ describe('Task Schema Validation', () => { }); it('should validate task with null ttl', () => { + const createdAt = new Date().toISOString(); const task: Task = { taskId: 'test-456', status: 'completed', ttl: null, - createdAt: new Date().toISOString() + createdAt, + lastUpdatedAt: createdAt }; expect(task.ttl).toBeNull(); }); it('should validate task with statusMessage field', () => { + const createdAt = new Date().toISOString(); const task: Task = { taskId: 'test-789', status: 'failed', ttl: null, - createdAt: new Date().toISOString(), + createdAt, + lastUpdatedAt: createdAt, statusMessage: 'Operation failed due to timeout' }; @@ -66,26 +72,44 @@ describe('Task Schema Validation', () => { it('should validate task with createdAt in ISO 8601 format', () => { const now = new Date(); + const createdAt = now.toISOString(); const task: Task = { taskId: 'test-iso', status: 'working', ttl: 30000, - createdAt: now.toISOString() + createdAt, + lastUpdatedAt: createdAt }; expect(task.createdAt).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/); expect(new Date(task.createdAt).getTime()).toBe(now.getTime()); }); + it('should validate task with lastUpdatedAt in ISO 8601 format', () => { + const now = new Date(); + const createdAt = now.toISOString(); + const task: Task = { + taskId: 'test-iso', + status: 'working', + ttl: 30000, + createdAt, + lastUpdatedAt: createdAt + }; + + expect(task.lastUpdatedAt).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/); + }); + it('should validate all task statuses', () => { const statuses: Task['status'][] = ['working', 'input_required', 'completed', 'failed', 'cancelled']; + const createdAt = new Date().toISOString(); statuses.forEach(status => { const task: Task = { taskId: `test-${status}`, status, ttl: null, - createdAt: new Date().toISOString() + createdAt, + lastUpdatedAt: createdAt }; expect(task.status).toBe(status); }); diff --git a/src/types.ts b/src/types.ts index 0c55e9dc1..47fba19de 100644 --- a/src/types.ts +++ b/src/types.ts @@ -652,6 +652,10 @@ export const TaskSchema = z.object({ * ISO 8601 timestamp when the task was created. */ createdAt: z.string(), + /** + * ISO 8601 timestamp when the task was last updated. + */ + lastUpdatedAt: z.string(), pollInterval: z.optional(z.number()), /** * Optional diagnostic message for failed tasks or other status information. From db3280d413989200f62ef9e93f903763305831ac Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 25 Nov 2025 08:55:51 -0800 Subject: [PATCH 73/84] Await a few promises --- src/client/index.test.ts | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 739fbba60..35ba5f46a 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -2284,8 +2284,8 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Create a task on client - const pending = server.request( + // Create a task on client and wait for completion + const result = await server.request( { method: 'elicitation/create', params: { @@ -2301,6 +2301,10 @@ describe('Task-based execution', () => { { task: { ttl: 60000 } } ); + // Verify the result was returned correctly + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ username: 'list-user' }); + // Get the task ID from the task list since it's generated automatically const taskList = await server.listTasks(); expect(taskList.tasks.length).toBeGreaterThan(0); @@ -2376,8 +2380,8 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Create a task on client - const pending = server.request( + // Create a task on client and wait for completion + const result = await server.request( { method: 'elicitation/create', params: { @@ -2393,15 +2397,19 @@ describe('Task-based execution', () => { { task: { ttl: 60000 } } ); + // Verify the result was returned correctly + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ username: 'result-user' }); + // Get the task ID from the task list since it's generated automatically const taskList = await server.listTasks(); expect(taskList.tasks.length).toBeGreaterThan(0); const taskId = taskList.tasks[0].taskId; - // Query task result - const result = await server.getTaskResult({ taskId }, ElicitResultSchema); - expect(result.action).toBe('accept'); - expect(result.content).toEqual({ username: 'result-user' }); + // Query task result using getTaskResult as well + const taskResult = await server.getTaskResult({ taskId }, ElicitResultSchema); + expect(taskResult.action).toBe('accept'); + expect(taskResult.content).toEqual({ username: 'result-user' }); }); test('should query task list from client using listTasks', async () => { @@ -2470,7 +2478,7 @@ describe('Task-based execution', () => { // Create multiple tasks on client const createdTaskIds: string[] = []; for (let i = 0; i < 2; i++) { - const pending = server.request( + const result = await server.request( { method: 'elicitation/create', params: { @@ -2486,6 +2494,10 @@ describe('Task-based execution', () => { { task: { ttl: 60000 } } ); + // Verify the result was returned correctly + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ username: 'list-user' }); + // Get the task ID from the task list const taskList = await server.listTasks(); const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); From bdc5aa41468631e005527182f080ee818cd3acbc Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 25 Nov 2025 09:41:17 -0800 Subject: [PATCH 74/84] Validate against CreateTaskResult in low-level client/server --- src/client/index.test.ts | 151 ++++++++++++++---------------- src/client/index.ts | 65 ++++++++++++- src/server/index.test.ts | 194 +++++++++++++++++++-------------------- src/server/index.ts | 98 +++++++++++++++++++- 4 files changed, 324 insertions(+), 184 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 35ba5f46a..f3acfded5 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -19,7 +19,8 @@ import { ElicitResultSchema, ListRootsRequestSchema, ErrorCode, - McpError + McpError, + CreateTaskResultSchema } from '../types.js'; import { Transport } from '../shared/transport.js'; import { Server } from '../server/index.js'; @@ -2150,22 +2151,22 @@ describe('Task-based execution', () => { ); client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - let taskId: string | undefined; + const result = { + action: 'accept', + content: { username: 'list-user' } + }; // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); - taskId = createdTask.taskId; - } - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; } + + // Return ElicitResult for non-task requests return result; }); @@ -2192,7 +2193,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Server creates task on client via elicitation - await server.request( + const createTaskResult = await server.request( { method: 'elicitation/create', params: { @@ -2207,14 +2208,14 @@ describe('Task-based execution', () => { } } }, - ElicitResultSchema, + CreateTaskResultSchema, { task: { ttl: 60000 } } ); - // Get the task ID from the task list since it's generated automatically - const taskList = await server.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; // Verify task was created const task = await server.getTask({ taskId }); @@ -2243,22 +2244,22 @@ describe('Task-based execution', () => { ); client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - let taskId: string | undefined; + const result = { + action: 'accept', + content: { username: 'list-user' } + }; // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); - taskId = createdTask.taskId; - } - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; } + + // Return ElicitResult for non-task requests return result; }); @@ -2284,8 +2285,8 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Create a task on client and wait for completion - const result = await server.request( + // Create a task on client and wait for CreateTaskResult + const createTaskResult = await server.request( { method: 'elicitation/create', params: { @@ -2297,18 +2298,14 @@ describe('Task-based execution', () => { } } }, - ElicitResultSchema, + CreateTaskResultSchema, { task: { ttl: 60000 } } ); - // Verify the result was returned correctly - expect(result.action).toBe('accept'); - expect(result.content).toEqual({ username: 'list-user' }); - - // Get the task ID from the task list since it's generated automatically - const taskList = await server.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; // Query task status const task = await server.getTask({ taskId }); @@ -2339,22 +2336,22 @@ describe('Task-based execution', () => { ); client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - let taskId: string | undefined; + const result = { + action: 'accept', + content: { username: 'result-user' } + }; // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); - taskId = createdTask.taskId; - } - const result = { - action: 'accept', - content: { username: 'result-user' } - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; } + + // Return ElicitResult for non-task requests return result; }); @@ -2380,8 +2377,8 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Create a task on client and wait for completion - const result = await server.request( + // Create a task on client and wait for CreateTaskResult + const createTaskResult = await server.request( { method: 'elicitation/create', params: { @@ -2393,20 +2390,16 @@ describe('Task-based execution', () => { } } }, - ElicitResultSchema, + CreateTaskResultSchema, { task: { ttl: 60000 } } ); - // Verify the result was returned correctly - expect(result.action).toBe('accept'); - expect(result.content).toEqual({ username: 'result-user' }); - - // Get the task ID from the task list since it's generated automatically - const taskList = await server.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; - // Query task result using getTaskResult as well + // Query task result using getTaskResult const taskResult = await server.getTaskResult({ taskId }, ElicitResultSchema); expect(taskResult.action).toBe('accept'); expect(taskResult.content).toEqual({ username: 'result-user' }); @@ -2434,22 +2427,22 @@ describe('Task-based execution', () => { ); client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - let taskId: string | undefined; + const result = { + action: 'accept', + content: { username: 'list-user' } + }; // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); - taskId = createdTask.taskId; - } - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; } + + // Return ElicitResult for non-task requests return result; }); @@ -2478,7 +2471,7 @@ describe('Task-based execution', () => { // Create multiple tasks on client const createdTaskIds: string[] = []; for (let i = 0; i < 2; i++) { - const result = await server.request( + const createTaskResult = await server.request( { method: 'elicitation/create', params: { @@ -2490,20 +2483,14 @@ describe('Task-based execution', () => { } } }, - ElicitResultSchema, + CreateTaskResultSchema, { task: { ttl: 60000 } } ); - // Verify the result was returned correctly - expect(result.action).toBe('accept'); - expect(result.content).toEqual({ username: 'list-user' }); - - // Get the task ID from the task list - const taskList = await server.listTasks(); - const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); - if (newTask) { - createdTaskIds.push(newTask.taskId); - } + // Verify CreateTaskResult structure and capture taskId + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + createdTaskIds.push(createTaskResult.task.taskId); } // Query task list diff --git a/src/client/index.ts b/src/client/index.ts index b33aa79d2..367f6c998 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -40,7 +40,10 @@ import { type Tool, type UnsubscribeRequest, ElicitResultSchema, - ElicitRequestSchema + ElicitRequestSchema, + CreateTaskResultSchema, + CreateMessageRequestSchema, + CreateMessageResultSchema } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js'; @@ -283,6 +286,20 @@ export class Client< const result = await Promise.resolve(handler(request, extra)); + // When task creation is requested, validate and return CreateTaskResult + if (params.task) { + const taskValidationResult = safeParse(CreateTaskResultSchema, result); + if (!taskValidationResult.success) { + const errorMessage = + taskValidationResult.error instanceof Error + ? taskValidationResult.error.message + : String(taskValidationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + } + return taskValidationResult.data; + } + + // For non-task requests, validate against ElicitResultSchema const validationResult = safeParse(ElicitResultSchema, result); if (!validationResult.success) { // Type guard: if success is false, error is guaranteed to exist @@ -311,7 +328,51 @@ export class Client< return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); } - // Non-elicitation handlers use default behavior + if (method === 'sampling/createMessage') { + const wrappedHandler = async ( + request: SchemaOutput, + extra: RequestHandlerExtra + ): Promise => { + const validatedRequest = safeParse(CreateMessageRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`); + } + + const { params } = validatedRequest.data; + + const result = await Promise.resolve(handler(request, extra)); + + // When task creation is requested, validate and return CreateTaskResult + if (params.task) { + const taskValidationResult = safeParse(CreateTaskResultSchema, result); + if (!taskValidationResult.success) { + const errorMessage = + taskValidationResult.error instanceof Error + ? taskValidationResult.error.message + : String(taskValidationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + } + return taskValidationResult.data; + } + + // For non-task requests, validate against CreateMessageResultSchema + const validationResult = safeParse(CreateMessageResultSchema, result); + if (!validationResult.success) { + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`); + } + + return validationResult.data; + }; + + // Install the wrapped handler + return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); + } + + // Other handlers use default behavior return super.setRequestHandler(requestSchema, handler); } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index a5970972b..6cec7427f 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -19,7 +19,8 @@ import { RequestSchema, ResultSchema, SetLevelRequestSchema, - SUPPORTED_PROTOCOL_VERSIONS + SUPPORTED_PROTOCOL_VERSIONS, + CreateTaskResultSchema } from '../types.js'; import { Server } from './index.js'; import { McpServer } from './mcp.js'; @@ -2362,22 +2363,22 @@ describe('Task-based execution', () => { ); client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - let taskId: string | undefined; + const result = { + action: 'accept', + content: { username: 'server-test-user', confirmed: true } + }; // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); - taskId = createdTask.taskId; - } - const result = { - action: 'accept', - content: { username: 'server-test-user', confirmed: true } - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; } + + // Return ElicitResult for non-task requests return result; }); @@ -2391,7 +2392,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Server creates task on client via elicitation - await server.request( + const createTaskResult = await server.request( { method: 'elicitation/create', params: { @@ -2407,14 +2408,14 @@ describe('Task-based execution', () => { } } }, - ElicitResultSchema, + CreateTaskResultSchema, { task: { ttl: 60000 } } ); - // Get the task ID from the task list since it's generated automatically - const taskList = await server.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; // Verify task was created const task = await server.getTask({ taskId }); @@ -2443,22 +2444,22 @@ describe('Task-based execution', () => { ); client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - let taskId: string | undefined; + const result = { + action: 'accept', + content: { username: 'list-user' } + }; // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); - taskId = createdTask.taskId; - } - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; } + + // Return ElicitResult for non-task requests return result; }); @@ -2472,7 +2473,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create task - await server.request( + const createTaskResult = await server.request( { method: 'elicitation/create', params: { @@ -2484,14 +2485,14 @@ describe('Task-based execution', () => { } } }, - ElicitResultSchema, + CreateTaskResultSchema, { task: { ttl: 60000 } } ); - // Get the task ID from the task list since it's generated automatically - const taskList = await server.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; // Query task const task = await server.getTask({ taskId }); @@ -2522,22 +2523,22 @@ describe('Task-based execution', () => { ); client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - let taskId: string | undefined; + const result = { + action: 'accept', + content: { username: 'result-user', confirmed: true } + }; // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); - taskId = createdTask.taskId; - } - const result = { - action: 'accept', - content: { username: 'result-user', confirmed: true } - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; } + + // Return ElicitResult for non-task requests return result; }); @@ -2551,7 +2552,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Create task - await server.request( + const createTaskResult = await server.request( { method: 'elicitation/create', params: { @@ -2566,14 +2567,14 @@ describe('Task-based execution', () => { } } }, - ElicitResultSchema, + CreateTaskResultSchema, { task: { ttl: 60000 } } ); - // Get the task ID from the task list since it's generated automatically - const taskList = await server.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; // Query result const result = await server.getTaskResult({ taskId }, ElicitResultSchema); @@ -2603,22 +2604,22 @@ describe('Task-based execution', () => { ); client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - let taskId: string | undefined; + const result = { + action: 'accept', + content: { username: 'list-user' } + }; // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); - taskId = createdTask.taskId; - } - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; } + + // Return ElicitResult for non-task requests return result; }); @@ -2647,7 +2648,7 @@ describe('Task-based execution', () => { // Create multiple tasks const createdTaskIds: string[] = []; for (let i = 0; i < 2; i++) { - await server.request( + const createTaskResult = await server.request( { method: 'elicitation/create', params: { @@ -2659,16 +2660,14 @@ describe('Task-based execution', () => { } } }, - ElicitResultSchema, + CreateTaskResultSchema, { task: { ttl: 60000 } } ); - // Get the task ID from the task list - const taskList = await server.listTasks(); - const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); - if (newTask) { - createdTaskIds.push(newTask.taskId); - } + // Verify CreateTaskResult structure and capture taskId + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + createdTaskIds.push(createTaskResult.task.taskId); } // Query task list @@ -2948,22 +2947,22 @@ test('should respect client task capabilities', async () => { ); client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { - let taskId: string | undefined; + const result = { + action: 'accept', + content: { username: 'test-user' } + }; // Check if task creation is requested if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); - taskId = createdTask.taskId; - } - const result = { - action: 'accept', - content: { username: 'test-user' } - }; - if (taskId && extra.taskStore) { - await extra.taskStore.storeTaskResult(taskId, 'completed', result); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; } + + // Return ElicitResult for non-task requests return result; }); @@ -3005,29 +3004,28 @@ test('should respect client task capabilities', async () => { }); // These should work because client supports tasks - await expect( - server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Test', - requestedSchema: { - type: 'object', - properties: { username: { type: 'string' } } - } + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Test', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } } - }, - ElicitResultSchema, - { task: { ttl: 60000 } } - ) - ).resolves.not.toThrow(); - await expect(server.listTasks()).resolves.not.toThrow(); + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); - // Get the task ID from the task list since it's generated automatically - const taskList = await server.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; + + await expect(server.listTasks()).resolves.not.toThrow(); await expect(server.getTask({ taskId })).resolves.not.toThrow(); // This should throw because client doesn't support task creation for sampling/createMessage diff --git a/src/server/index.ts b/src/server/index.ts index e348514f7..23061bf98 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -33,11 +33,24 @@ import { SetLevelRequestSchema, SUPPORTED_PROTOCOL_VERSIONS, type ToolResultContent, - type ToolUseContent + type ToolUseContent, + CallToolRequestSchema, + CallToolResultSchema, + CreateTaskResultSchema } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, jsonSchemaValidator } from '../validation/types.js'; -import { AnySchema, SchemaOutput } from './zod-compat.js'; +import { + AnyObjectSchema, + AnySchema, + getObjectShape, + isZ4Schema, + safeParse, + SchemaOutput, + type ZodV3Internal, + type ZodV4Internal +} from './zod-compat.js'; +import { RequestHandlerExtra } from '../shared/protocol.js'; export type ServerOptions = ProtocolOptions & { /** @@ -177,6 +190,87 @@ export class Server< this._capabilities = mergeCapabilities(this._capabilities, capabilities); } + /** + * Override request handler registration to enforce server-side validation for tools/call. + */ + public override setRequestHandler( + requestSchema: T, + handler: ( + request: SchemaOutput, + extra: RequestHandlerExtra + ) => ServerResult | ResultT | Promise + ): void { + const shape = getObjectShape(requestSchema); + const methodSchema = shape?.method; + if (!methodSchema) { + throw new Error('Schema is missing a method literal'); + } + + // Extract literal value using type-safe property access + let methodValue: unknown; + if (isZ4Schema(methodSchema)) { + const v4Schema = methodSchema as unknown as ZodV4Internal; + const v4Def = v4Schema._zod?.def; + methodValue = v4Def?.value ?? v4Schema.value; + } else { + const v3Schema = methodSchema as unknown as ZodV3Internal; + const legacyDef = v3Schema._def; + methodValue = legacyDef?.value ?? v3Schema.value; + } + + if (typeof methodValue !== 'string') { + throw new Error('Schema method literal must be a string'); + } + const method = methodValue; + + if (method === 'tools/call') { + const wrappedHandler = async ( + request: SchemaOutput, + extra: RequestHandlerExtra + ): Promise => { + const validatedRequest = safeParse(CallToolRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); + } + + const { params } = validatedRequest.data; + + const result = await Promise.resolve(handler(request, extra)); + + // When task creation is requested, validate and return CreateTaskResult + if (params.task) { + const taskValidationResult = safeParse(CreateTaskResultSchema, result); + if (!taskValidationResult.success) { + const errorMessage = + taskValidationResult.error instanceof Error + ? taskValidationResult.error.message + : String(taskValidationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + } + return taskValidationResult.data; + } + + // For non-task requests, validate against CallToolResultSchema + const validationResult = safeParse(CallToolResultSchema, result); + if (!validationResult.success) { + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); + } + + return validationResult.data; + }; + + // Install the wrapped handler + return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); + } + + // Other handlers use default behavior + return super.setRequestHandler(requestSchema, handler); + } + protected assertCapabilityForMethod(method: RequestT['method']): void { switch (method as ServerRequest['method']) { case 'sampling/createMessage': From 4782f9dbe1c47f2a8e81a99071f75979c3fbcb36 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 25 Nov 2025 09:56:59 -0800 Subject: [PATCH 75/84] Use CreateTaskResult for task ID in tests --- src/client/index.test.ts | 20 ++++++++++++------- src/server/index.test.ts | 42 ++++++++++++++++++++++++---------------- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index f3acfded5..38af2a841 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -2013,16 +2013,22 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Create a task - await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + // Create a task using callToolStream to capture the task ID + let taskId: string | undefined; + const stream = client.callToolStream({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); - // Get the task ID from the task list and query task result - const taskList = await client.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; - const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + for await (const message of stream) { + if (message.type === 'taskCreated') { + taskId = message.task.taskId; + } + } + + expect(taskId).toBeDefined(); + + // Query task result using the captured task ID + const result = await client.getTaskResult({ taskId: taskId! }, CallToolResultSchema); expect(result.content).toEqual([{ type: 'text', text: 'Result data!' }]); }); diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 6cec7427f..b1fb8a77a 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -2076,28 +2076,32 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Use callTool to create a task - await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + // Use callToolStream to create a task and capture the task ID + let taskId: string | undefined; + const stream = client.callToolStream({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); + for await (const message of stream) { + if (message.type === 'taskCreated') { + taskId = message.task.taskId; + } + } + + expect(taskId).toBeDefined(); + // Wait for the task to complete await new Promise(resolve => setTimeout(resolve, 50)); - // Get the task ID from the task list since it's generated automatically - const taskList = await client.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; - // Verify we can retrieve the task - const task = await client.getTask({ taskId }); + const task = await client.getTask({ taskId: taskId! }); expect(task).toBeDefined(); expect(task.status).toBe('completed'); // Verify we can retrieve the result - const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + const result = await client.getTaskResult({ taskId: taskId! }, CallToolResultSchema); expect(result.content).toEqual([{ type: 'text', text: 'Tool executed successfully!' }]); // Cleanup @@ -2299,26 +2303,30 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Call tool WITH task creation - await client.callTool({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { + // Call tool WITH task creation using callToolStream to capture task ID + let taskId: string | undefined; + const stream = client.callToolStream({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); + for await (const message of stream) { + if (message.type === 'taskCreated') { + taskId = message.task.taskId; + } + } + + expect(taskId).toBeDefined(); + // Wait for completion await new Promise(resolve => setTimeout(resolve, 50)); // Verify the nested elicitation request was made (related-task metadata is no longer automatically attached) expect(capturedElicitRequest).toBeDefined(); - // Get the task ID from the task list since it's generated automatically - const taskList = await client.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const taskId = taskList.tasks[0].taskId; - // Verify tool result was correct - const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + const result = await client.getTaskResult({ taskId: taskId! }, CallToolResultSchema); expect(result.content).toEqual([ { type: 'text', From 681cf0d086844b0690d1a871a424abe5033e6cd0 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 25 Nov 2025 10:36:07 -0800 Subject: [PATCH 76/84] Fix TaskStatusNotificationParamsSchema --- src/shared/protocol.test.ts | 26 ++++++++++++-------------- src/shared/protocol.ts | 8 ++------ src/types.ts | 4 +--- 3 files changed, 15 insertions(+), 23 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 9ec39c871..267d70251 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -2657,13 +2657,12 @@ describe('Progress notification support for tasks', () => { jsonrpc: '2.0', method: 'notifications/tasks/status', params: { - task: { - taskId, - status: 'failed', - ttl: 60000, - createdAt: new Date().toISOString(), - statusMessage: 'Task failed' - } + taskId, + status: 'failed', + ttl: 60000, + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + statusMessage: 'Task failed' } }); } @@ -2753,13 +2752,12 @@ describe('Progress notification support for tasks', () => { jsonrpc: '2.0', method: 'notifications/tasks/status', params: { - task: { - taskId, - status: 'cancelled', - ttl: 60000, - createdAt: new Date().toISOString(), - statusMessage: 'User cancelled' - } + taskId, + status: 'cancelled', + ttl: 60000, + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + statusMessage: 'User cancelled' } }); } diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 757a68ff5..15d74fe7f 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1497,9 +1497,7 @@ export abstract class Protocol Date: Tue, 25 Nov 2025 12:18:14 -0800 Subject: [PATCH 77/84] Add missing execution field to ListTools --- src/server/mcp.test.ts | 138 +++++++++++++++++++++++++++++++++++++++++ src/server/mcp.ts | 1 + 2 files changed, 139 insertions(+) diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index bb25440ac..c3fd30cc3 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -1774,6 +1774,144 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(result.tools[0]._meta).toBeUndefined(); }); + test('should include execution field in listTools response when tool has execution settings', async () => { + const taskStore = new InMemoryTaskStore(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a tool with execution.taskSupport + mcpServer.registerToolTask( + 'task-tool', + { + description: 'A tool with task support', + inputSchema: { input: z.string() }, + execution: { + taskSupport: 'required' + } + }, + { + createTask: async (_args, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000 }); + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) throw new Error('Task not found'); + return task; + }, + getTaskResult: async (_args, extra) => { + return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe('task-tool'); + expect(result.tools[0].execution).toEqual({ + taskSupport: 'required' + }); + + taskStore.cleanup(); + }); + + test('should include execution field with taskSupport optional in listTools response', async () => { + const taskStore = new InMemoryTaskStore(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a tool with execution.taskSupport optional + mcpServer.registerToolTask( + 'optional-task-tool', + { + description: 'A tool with optional task support', + inputSchema: { input: z.string() }, + execution: { + taskSupport: 'optional' + } + }, + { + createTask: async (_args, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000 }); + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) throw new Error('Task not found'); + return task; + }, + getTaskResult: async (_args, extra) => { + return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe('optional-task-tool'); + expect(result.tools[0].execution).toEqual({ + taskSupport: 'optional' + }); + + taskStore.cleanup(); + }); + test('should validate tool names according to SEP specification', () => { // Create a new server instance for this test const testServer = new McpServer({ diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 11707dc14..a727a4f33 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -136,6 +136,7 @@ export class McpServer { : EMPTY_OBJECT_JSON_SCHEMA; })(), annotations: tool.annotations, + execution: tool.execution, _meta: tool._meta }; From 333a2576bf629bd5081aa80a0a61bd2fb3be898a Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 26 Nov 2025 10:18:57 -0800 Subject: [PATCH 78/84] Remove bit that deletes related-task metadata unnecessarily --- src/shared/protocol.ts | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 15d74fe7f..674820a4a 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -423,22 +423,9 @@ export abstract class Protocol Date: Wed, 26 Nov 2025 11:25:14 -0800 Subject: [PATCH 79/84] Implement response/error queueing and fix input_required Responses and errors were incorrectly going through the original stream path instead of being queued. Also, extra.sendRequest was not setting the input_required status. These issues have been fixed and tests have been added/updated for them. Sequence diagram of the intended flow: ```mermaid sequenceDiagram participant C as Client Protocol participant CT as Client Transport participant ST as Server Transport participant S as Server Protocol participant TQ as TaskMessageQueue participant TS as TaskStore participant H as Async Handler Note over C,H: Phase 1: Task Creation activate C C->>CT: tools/call { task: { ttl: 60000 } } activate CT CT->>ST: HTTP POST activate ST ST->>S: _onrequest() activate S S->>TS: createTask() activate TS TS-->>S: Task { taskId, status: 'working' } deactivate TS S--)H: Start async handler (non-blocking) activate H S-->>ST: CreateTaskResult { task } deactivate S ST-->>CT: HTTP Response deactivate ST CT-->>C: CreateTaskResult deactivate CT deactivate C Note over C,H: Phase 2: Server Queues Elicitation Request H->>S: extra.sendRequest(elicitation, { relatedTask }) activate S S->>TQ: enqueue({ type: 'request', message: elicitation }) activate TQ TQ-->>S: OK deactivate TQ S->>S: Store resolver in _requestResolvers Note over S: Promise waiting... deactivate S H->>TS: updateTaskStatus('input_required') activate TS TS-->>H: OK deactivate TS Note over H: Blocked awaiting elicitation response Note over C,H: Phase 3: Client Polls Status activate C C->>CT: tasks/get { taskId } activate CT CT->>ST: HTTP POST activate ST ST->>S: _onrequest(GetTask) activate S S->>TS: getTask(taskId) activate TS TS-->>S: Task { status: 'input_required' } deactivate TS S-->>ST: Task deactivate S ST-->>CT: HTTP Response deactivate ST CT-->>C: Task { status: 'input_required' } deactivate CT deactivate C Note over C,H: Phase 4: Client Fetches Queued Messages activate C C->>CT: tasks/result { taskId } activate CT CT->>ST: HTTP POST activate ST ST->>S: _onrequest(GetTaskPayload) activate S S->>TQ: dequeue(taskId) activate TQ TQ-->>S: { type: 'request', message: elicitation } deactivate TQ S->>ST: send(elicitation, { relatedRequestId }) ST-->>CT: SSE Event: elicitation request Note over S: Handler blocks (task not terminal) Note over C,H: Phase 5: Client Handles & Responds CT->>C: _onrequest(elicitation) activate C Note over C: Extract relatedTaskId from _meta C->>C: Call ElicitRequestSchema handler C->>C: Check: relatedTaskId && _taskMessageQueue Note over C: _taskMessageQueue is undefined C->>CT: transport.send(response) CT->>ST: HTTP POST (elicitation response) deactivate C Note over C,H: Phase 6: Server Receives Response, Resolves Promise ST->>S: _onresponse(elicitation response) S->>S: Lookup resolver in _requestResolvers S->>S: resolver(response) Note over S: Promise resolves S-->>H: Elicitation result { action: 'accept', content } Note over H: Resumes execution Note over C,H: Phase 7: Task Completes H->>TS: storeTaskResult('completed', finalResult) activate TS TS-->>H: OK deactivate TS deactivate H Note over S: GetTaskPayload handler wakes up S->>TS: getTask(taskId) activate TS TS-->>S: Task { status: 'completed' } deactivate TS S->>TS: getTaskResult(taskId) activate TS TS-->>S: CallToolResult deactivate TS S-->>ST: Return final result deactivate S ST-->>CT: SSE Event: CallToolResult deactivate ST CT-->>C: CallToolResult { content: [...] } deactivate CT deactivate C ``` --- src/integration-tests/taskLifecycle.test.ts | 194 ++++++++++---------- src/shared/protocol.test.ts | 166 +++++++++++++++++ src/shared/protocol.ts | 154 +++++++++++----- 3 files changed, 364 insertions(+), 150 deletions(-) diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts index 3aba46b07..9595df6d4 100644 --- a/src/integration-tests/taskLifecycle.test.ts +++ b/src/integration-tests/taskLifecycle.test.ts @@ -5,7 +5,14 @@ import { Client } from '../client/index.js'; import { StreamableHTTPClientTransport } from '../client/streamableHttp.js'; import { McpServer } from '../server/mcp.js'; import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; -import { CallToolResultSchema, CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, TaskSchema } from '../types.js'; +import { + CallToolResultSchema, + CreateTaskResultSchema, + ElicitRequestSchema, + ElicitResultSchema, + RELATED_TASK_META_KEY, + TaskSchema +} from '../types.js'; import { z } from 'zod'; import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; import type { TaskRequestOptions } from '../shared/protocol.js'; @@ -565,6 +572,15 @@ describe('Task Lifecycle Integration Tests', () => { describe('Input Required Flow', () => { it('should handle elicitation during tool execution', async () => { + // Complete flow phases: + // 1. Client creates task + // 2. Server queues elicitation request and sets status to input_required + // 3. Client polls tasks/get, sees input_required status + // 4. Client calls tasks/result to dequeue elicitation request + // 5. Client responds to elicitation + // 6. Server receives response, completes task + // 7. Client receives final result + const elicitClient = new Client( { name: 'test-client', @@ -577,26 +593,27 @@ describe('Task Lifecycle Integration Tests', () => { } ); + // Track elicitation request receipt + let elicitationReceived = false; + let elicitationRequestMeta: Record | undefined; + // Set up elicitation handler on client elicitClient.setRequestHandler(ElicitRequestSchema, async request => { - // Verify elicitation request structure - expect(request.params.message).toBe('What is your name?'); - expect(request.params.requestedSchema).toHaveProperty('properties'); + elicitationReceived = true; + elicitationRequestMeta = request.params._meta; - // Respond with user input - const response = { + return { action: 'accept' as const, content: { - userName: 'Alice' + userName: 'TestUser' } }; - return response; }); const transport = new StreamableHTTPClientTransport(baseUrl); await elicitClient.connect(transport); - // Create a task without userName (will trigger elicitation) + // Phase 1: Create task const createResult = await elicitClient.request( { method: 'tools/call', @@ -612,15 +629,36 @@ describe('Task Lifecycle Integration Tests', () => { ); const taskId = createResult.task.taskId; + expect(createResult.task.status).toBe('working'); - // Wait for elicitation to occur - await new Promise(resolve => setTimeout(resolve, 200)); + // Phase 2: Wait for server to queue elicitation and update status + // Poll tasks/get until we see input_required status + let taskStatus: string = 'working'; + const maxPolls = 20; + let polls = 0; + + while (taskStatus === 'working' && polls < maxPolls) { + await new Promise(resolve => setTimeout(resolve, createResult.task.pollInterval ?? 100)); + const task = await elicitClient.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + taskStatus = task.status; + polls++; + } - // Check if the elicitation request was queued + // Verify we saw input_required status (not completed or failed) + expect(taskStatus).toBe('input_required'); - // Call tasks/result to receive the queued elicitation request - // This should deliver the elicitation request via the side-channel - // and then deliver the final result after the client responds + // Phase 3: Call tasks/result to dequeue messages and get final result + // This should: + // - Deliver the queued elicitation request via SSE + // - Client handler responds + // - Server receives response, completes task + // - Return final result const result = await elicitClient.request( { method: 'tasks/result', @@ -629,76 +667,28 @@ describe('Task Lifecycle Integration Tests', () => { CallToolResultSchema ); - // Verify final result is delivered correctly - expect(result.content).toEqual([{ type: 'text', text: 'Hello, Alice!' }]); + // Verify elicitation was received and processed + expect(elicitationReceived).toBe(true); - // Verify task is now completed - const task = await elicitClient.request( - { - method: 'tasks/get', - params: { taskId } - }, - TaskSchema - ); - expect(task.status).toBe('completed'); - - await transport.close(); - }, 10000); // Increase timeout to 10 seconds for debugging + // Verify the elicitation request had related-task metadata + expect(elicitationRequestMeta).toBeDefined(); + expect(elicitationRequestMeta?.[RELATED_TASK_META_KEY]).toEqual({ taskId }); - it('should complete immediately when input is provided upfront', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a task with userName provided (no elicitation needed) - const createResult = await client.request( - { - method: 'tools/call', - params: { - name: 'input-task', - arguments: { - userName: 'Bob' - }, - task: { - ttl: 60000 - } - } - }, - CreateTaskResultSchema - ); - - const taskId = createResult.task.taskId; - - // Wait for completion - await new Promise(resolve => setTimeout(resolve, 300)); + // Verify final result + expect(result.content).toEqual([{ type: 'text', text: 'Hello, TestUser!' }]); - // Verify task completed without elicitation - const task = await client.request( + // Verify task is now completed + const finalTask = await elicitClient.request( { method: 'tasks/get', params: { taskId } }, TaskSchema ); - expect(task.status).toBe('completed'); - - // Get result - const result = await client.request( - { - method: 'tasks/result', - params: { taskId } - }, - CallToolResultSchema - ); - - expect(result.content).toEqual([{ type: 'text', text: 'Hello, Bob!' }]); + expect(finalTask.status).toBe('completed'); await transport.close(); - }); + }, 15000); }); describe('Task Listing and Pagination', () => { @@ -1013,7 +1003,7 @@ describe('Task Lifecycle Integration Tests', () => { // Wait for messages to be queued await new Promise(resolve => setTimeout(resolve, 200)); - // Verify task is working and messages are queued + // Verify task is in input_required state and messages are queued let task = await client.request( { method: 'tasks/get', @@ -1021,7 +1011,7 @@ describe('Task Lifecycle Integration Tests', () => { }, TaskSchema ); - expect(task.status).toBe('working'); + expect(task.status).toBe('input_required'); // Cancel the task before calling tasks/result using the proper tasks/cancel request // This will trigger queue cleanup via _clearTaskQueue in the handler @@ -1322,33 +1312,36 @@ describe('Task Lifecycle Integration Tests', () => { // Perform async work that queues messages and completes quickly (async () => { try { - // Queue messages without waiting for responses - const pendingRequests: Promise[] = []; - + // Queue messages - these will be queued before the task completes + // We await each one starting to ensure they're queued before completing for (let i = 0; i < messageCount; i++) { - const requestPromise = extra.sendRequest( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: `Quick message ${i + 1} of ${messageCount}`, - requestedSchema: { - type: 'object', - properties: { - response: { type: 'string' } - }, - required: ['response'] + // Start the request but don't wait for response + // The request gets queued when sendRequest is called + extra + .sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: `Quick message ${i + 1} of ${messageCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] + } } - } - }, - ElicitResultSchema, - { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions - ); - pendingRequests.push(requestPromise); + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ) + .catch(() => {}); + // Small delay to ensure message is queued before next iteration + await new Promise(resolve => setTimeout(resolve, 10)); } - // Complete the task immediately (before responses are received) - // This creates a terminal task with queued messages + // Complete the task after all messages are queued try { await extra.taskStore.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'Task completed quickly' }] @@ -1356,9 +1349,6 @@ describe('Task Lifecycle Integration Tests', () => { } catch { // Task may have been cleaned up if test ended } - - // Wait for all responses in the background - await Promise.all(pendingRequests.map(p => p.catch(() => {}))); } catch (error) { // Handle errors try { diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 267d70251..eeb6a9237 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -3661,6 +3661,172 @@ describe('Message Interception', () => { }); }); + describe('server queues responses/errors for task-related requests', () => { + it('should queue response when handling a request with relatedTask metadata', async () => { + await protocol.connect(transport); + + // Set up a request handler that returns a result + const TestRequestSchema = z.object({ + method: z.literal('test/taskRequest'), + params: z + .object({ + _meta: z.optional(z.record(z.unknown())) + }) + .passthrough() + }); + + protocol.setRequestHandler(TestRequestSchema, async () => { + return { content: 'test result' } as Result; + }); + + // Simulate an incoming request with relatedTask metadata + const requestId = 456; + const taskId = 'task-response-test'; + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + method: 'test/taskRequest', + params: { + _meta: { + 'io.modelcontextprotocol/related-task': { taskId } + } + } + }); + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the response was queued instead of sent directly + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue(taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage!.type).toBe('response'); + if (queuedMessage!.type === 'response') { + expect(queuedMessage!.message.id).toBe(requestId); + expect(queuedMessage!.message.result).toEqual({ content: 'test result' }); + } + }); + + it('should queue error when handling a request with relatedTask metadata that throws', async () => { + await protocol.connect(transport); + + // Set up a request handler that throws an error + const TestRequestSchema = z.object({ + method: z.literal('test/taskRequestError'), + params: z + .object({ + _meta: z.optional(z.record(z.unknown())) + }) + .passthrough() + }); + + protocol.setRequestHandler(TestRequestSchema, async () => { + throw new McpError(ErrorCode.InternalError, 'Test error message'); + }); + + // Simulate an incoming request with relatedTask metadata + const requestId = 789; + const taskId = 'task-error-test'; + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + method: 'test/taskRequestError', + params: { + _meta: { + 'io.modelcontextprotocol/related-task': { taskId } + } + } + }); + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the error was queued instead of sent directly + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue(taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage!.type).toBe('error'); + if (queuedMessage!.type === 'error') { + expect(queuedMessage!.message.id).toBe(requestId); + expect(queuedMessage!.message.error.code).toBe(ErrorCode.InternalError); + expect(queuedMessage!.message.error.message).toContain('Test error message'); + } + }); + + it('should queue MethodNotFound error for unknown method with relatedTask metadata', async () => { + await protocol.connect(transport); + + // Simulate an incoming request for unknown method with relatedTask metadata + const requestId = 101; + const taskId = 'task-not-found-test'; + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + method: 'unknown/method', + params: { + _meta: { + 'io.modelcontextprotocol/related-task': { taskId } + } + } + }); + + // Wait for processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the error was queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue(taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage!.type).toBe('error'); + if (queuedMessage!.type === 'error') { + expect(queuedMessage!.message.id).toBe(requestId); + expect(queuedMessage!.message.error.code).toBe(ErrorCode.MethodNotFound); + } + }); + + it('should send response normally when request has no relatedTask metadata', async () => { + await protocol.connect(transport); + const sendSpy = vi.spyOn(transport, 'send'); + + // Set up a request handler + const TestRequestSchema = z.object({ + method: z.literal('test/normalRequest'), + params: z.optional(z.record(z.unknown())) + }); + + protocol.setRequestHandler(TestRequestSchema, async () => { + return { content: 'normal result' } as Result; + }); + + // Simulate an incoming request WITHOUT relatedTask metadata + const requestId = 202; + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + method: 'test/normalRequest', + params: {} + }); + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the response was sent through transport, not queued + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: requestId, + result: { content: 'normal result' } + }) + ); + }); + }); + describe('messages without metadata bypass the queue', () => { it('should not queue notifications without relatedTask metadata', async () => { await protocol.connect(transport); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 674820a4a..a1fdb756e 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -650,17 +650,35 @@ export abstract class Protocol this._onerror(new Error(`Failed to send an error response: ${error}`))); + const errorResponse: JSONRPCError = { + jsonrpc: '2.0', + id: request.id, + error: { + code: ErrorCode.MethodNotFound, + message: 'Method not found' + } + }; + + // Queue or send the error response based on whether this is a task-related request + if (relatedTaskId && this._taskMessageQueue) { + this._enqueueTaskMessage( + relatedTaskId, + { + type: 'error', + message: errorResponse, + timestamp: Date.now() + }, + capturedTransport?.sessionId + ).catch(error => this._onerror(new Error(`Failed to enqueue error response: ${error}`))); + } else { + capturedTransport + ?.send(errorResponse) + .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); + } return; } @@ -670,9 +688,6 @@ export abstract class Protocol = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, @@ -691,6 +706,26 @@ export abstract class Protocol { if (abortController.signal.aborted) { @@ -731,7 +780,7 @@ export abstract class Protocol this._onerror(new Error(`Failed to send response: ${error}`))) @@ -1498,41 +1562,35 @@ export abstract class Protocol { - try { - // Check if task is in terminal state before attempting to update - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - return; - } + // Check if task exists + const task = await taskStore.getTask(taskId, sessionId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, `Task "${taskId}" not found - it may have been cleaned up`); + } - // Don't allow transitions from terminal states - if (isTerminal(task.status)) { - this._onerror( - new Error( - `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` - ) - ); - return; - } + // Don't allow transitions from terminal states + if (isTerminal(task.status)) { + throw new McpError( + ErrorCode.InvalidParams, + `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` + ); + } - await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); + await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); - // Get updated task state and send notification - const updatedTask = await taskStore.getTask(taskId, sessionId); - if (updatedTask) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: updatedTask - }); - await this.notification(notification as SendNotificationT); + // Get updated task state and send notification + const updatedTask = await taskStore.getTask(taskId, sessionId); + if (updatedTask) { + const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ + method: 'notifications/tasks/status', + params: updatedTask + }); + await this.notification(notification as SendNotificationT); - if (isTerminal(updatedTask.status)) { - this._cleanupTaskProgressHandler(taskId); - // Don't clear queue here - it will be cleared after delivery via tasks/result - } + if (isTerminal(updatedTask.status)) { + this._cleanupTaskProgressHandler(taskId); + // Don't clear queue here - it will be cleared after delivery via tasks/result } - } catch (error) { - throw new Error(`Failed to update status of task "${taskId}" to "${status}": ${error}`); } }, listTasks: cursor => { From 6bb2444512edc9e22e371b1ce6323ac02b25f81a Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 26 Nov 2025 11:28:36 -0800 Subject: [PATCH 80/84] Let taskStore.updateTaskStatus throw directly if task is cancelled --- src/shared/protocol.ts | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a1fdb756e..de8660688 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -706,26 +706,14 @@ export abstract class Protocol Date: Wed, 26 Nov 2025 12:01:46 -0800 Subject: [PATCH 81/84] Return task data in cancellation result --- src/integration-tests/taskLifecycle.test.ts | 95 ++++++++++++++------- src/shared/protocol.test.ts | 11 ++- src/shared/protocol.ts | 11 ++- src/shared/task-listing.test.ts | 12 ++- src/types.ts | 2 +- 5 files changed, 90 insertions(+), 41 deletions(-) diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts index 9595df6d4..fb58b7d78 100644 --- a/src/integration-tests/taskLifecycle.test.ts +++ b/src/integration-tests/taskLifecycle.test.ts @@ -10,6 +10,8 @@ import { CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, + ErrorCode, + McpError, RELATED_TASK_META_KEY, TaskSchema } from '../types.js'; @@ -316,7 +318,7 @@ describe('Task Lifecycle Integration Tests', () => { }); describe('Task Cancellation', () => { - it('should cancel a working task', async () => { + it('should cancel a working task and return the cancelled task', async () => { const client = new Client({ name: 'test-client', version: '1.0.0' @@ -348,17 +350,24 @@ describe('Task Lifecycle Integration Tests', () => { let task = await taskStore.getTask(taskId); expect(task?.status).toBe('working'); - // Cancel the task - await taskStore.updateTaskStatus(taskId, 'cancelled'); + // Cancel the task via client.cancelTask - per spec, returns Result & Task + const cancelResult = await client.cancelTask({ taskId }); - // Verify task is cancelled + // Verify the cancel response includes the cancelled task (per MCP spec CancelTaskResult is Result & Task) + expect(cancelResult.taskId).toBe(taskId); + expect(cancelResult.status).toBe('cancelled'); + expect(cancelResult.createdAt).toBeDefined(); + expect(cancelResult.lastUpdatedAt).toBeDefined(); + expect(cancelResult.ttl).toBeDefined(); + + // Verify task is cancelled in store as well task = await taskStore.getTask(taskId); expect(task?.status).toBe('cancelled'); await transport.close(); }); - it('should reject cancellation of completed task', async () => { + it('should reject cancellation of completed task with error code -32602', async () => { const client = new Client({ name: 'test-client', version: '1.0.0' @@ -393,8 +402,13 @@ describe('Task Lifecycle Integration Tests', () => { const task = await taskStore.getTask(taskId); expect(task?.status).toBe('completed'); - // Try to cancel (should fail) - await expect(taskStore.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow(); + // Try to cancel via tasks/cancel request (should fail with -32602) + await expect(client.cancelTask({ taskId })).rejects.toSatisfy((error: McpError) => { + expect(error).toBeInstanceOf(McpError); + expect(error.code).toBe(ErrorCode.InvalidParams); + expect(error.message).toContain('Cannot cancel task in terminal status'); + return true; + }); await transport.close(); }); @@ -775,7 +789,7 @@ describe('Task Lifecycle Integration Tests', () => { }); describe('Error Handling', () => { - it('should return null for non-existent task', async () => { + it('should return error code -32602 for non-existent task in tasks/get', async () => { const client = new Client({ name: 'test-client', version: '1.0.0' @@ -784,14 +798,18 @@ describe('Task Lifecycle Integration Tests', () => { const transport = new StreamableHTTPClientTransport(baseUrl); await client.connect(transport); - // Try to get non-existent task - const task = await taskStore.getTask('non-existent'); - expect(task).toBeNull(); + // Try to get non-existent task via tasks/get request + await expect(client.getTask({ taskId: 'non-existent-task-id' })).rejects.toSatisfy((error: McpError) => { + expect(error).toBeInstanceOf(McpError); + expect(error.code).toBe(ErrorCode.InvalidParams); + expect(error.message).toContain('Task not found'); + return true; + }); await transport.close(); }); - it('should return error for invalid task operation', async () => { + it('should return error code -32602 for non-existent task in tasks/cancel', async () => { const client = new Client({ name: 'test-client', version: '1.0.0' @@ -800,30 +818,41 @@ describe('Task Lifecycle Integration Tests', () => { const transport = new StreamableHTTPClientTransport(baseUrl); await client.connect(transport); - // Create and complete a task - const createResult = await client.request( - { - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 100 - }, - task: { - ttl: 60000 - } - } - }, - CreateTaskResultSchema - ); + // Try to cancel non-existent task via tasks/cancel request + await expect(client.cancelTask({ taskId: 'non-existent-task-id' })).rejects.toSatisfy((error: McpError) => { + expect(error).toBeInstanceOf(McpError); + expect(error.code).toBe(ErrorCode.InvalidParams); + expect(error.message).toContain('Task not found'); + return true; + }); - const taskId = createResult.task.taskId; + await transport.close(); + }); - // Wait for completion - await new Promise(resolve => setTimeout(resolve, 200)); + it('should return error code -32602 for non-existent task in tasks/result', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); - // Try to cancel completed task (should fail) - await expect(taskStore.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow(); + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Try to get result of non-existent task via tasks/result request + await expect( + client.request( + { + method: 'tasks/result', + params: { taskId: 'non-existent-task-id' } + }, + CallToolResultSchema + ) + ).rejects.toSatisfy((error: McpError) => { + expect(error).toBeInstanceOf(McpError); + expect(error.code).toBe(ErrorCode.InvalidParams); + expect(error.message).toContain('Task not found'); + return true; + }); await transport.close(); }); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index eeb6a9237..f56c806fc 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1715,13 +1715,18 @@ describe('Task-based execution', () => { const deleteTaskPromise = protocol.cancelTask({ taskId: 'task-to-delete' }); - // Simulate server response + // Simulate server response - per MCP spec, CancelTaskResult is Result & Task setTimeout(() => { transport.onmessage?.({ jsonrpc: '2.0', id: sendSpy.mock.calls[0][0].id, result: { - _meta: {} + _meta: {}, + taskId: 'task-to-delete', + status: 'cancelled', + ttl: 60000, + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString() } }); }, 0); @@ -1738,6 +1743,8 @@ describe('Task-based execution', () => { expect.any(Object) ); expect(result._meta).toBeDefined(); + expect(result.taskId).toBe('task-to-delete'); + expect(result.status).toBe('cancelled'); }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index de8660688..31532d85a 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -507,9 +507,16 @@ export abstract class Protocol { let client: Client; @@ -125,14 +126,19 @@ describe('Task Listing with Pagination', () => { expect(result.tasks).toHaveLength(2); }); - it('should return error for invalid cursor', async () => { + it('should return error code -32602 for invalid cursor', async () => { await taskStore.createTask({}, 1, { method: 'tools/call', params: { name: 'test-tool' } }); - // Try to use an invalid cursor - await expect(client.listTasks({ cursor: 'invalid-cursor' })).rejects.toThrow(); + // Try to use an invalid cursor - should return -32602 (Invalid params) per MCP spec + await expect(client.listTasks({ cursor: 'invalid-cursor' })).rejects.toSatisfy((error: McpError) => { + expect(error).toBeInstanceOf(McpError); + expect(error.code).toBe(ErrorCode.InvalidParams); + expect(error.message).toContain('Invalid cursor'); + return true; + }); }); it('should ensure tasks accessible via tasks/get are also accessible via tasks/list', async () => { diff --git a/src/types.ts b/src/types.ts index 0763f628a..218393bf1 100644 --- a/src/types.ts +++ b/src/types.ts @@ -735,7 +735,7 @@ export const CancelTaskRequestSchema = RequestSchema.extend({ /** * The response to a tasks/cancel request. */ -export const CancelTaskResultSchema = ResultSchema; +export const CancelTaskResultSchema = ResultSchema.merge(TaskSchema); /* Resources */ /** From c09327674167adb69ca61db20426d551d61b911b Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Wed, 26 Nov 2025 14:47:07 +0000 Subject: [PATCH 82/84] feat(experimental): Create experimental/tasks module structure Phase 1-2 of tasks experimental isolation: - Create src/experimental/tasks/ directory structure - Move TaskStore, TaskMessageQueue, and related interfaces to experimental/tasks/interfaces.ts - Add experimental/tasks/types.ts for re-exporting spec types - Update shared/task.ts to re-export from experimental for backward compatibility - Add barrel exports for experimental module All tests pass (1399 tests). --- README.md | 16 +- package.json | 8 + src/client/index.test.ts | 54 ++-- src/client/index.ts | 220 ++------------- src/examples/client/simpleOAuthClient.ts | 3 +- src/examples/client/simpleStreamableHttp.ts | 3 +- src/examples/server/simpleStreamableHttp.ts | 5 +- src/experimental/index.ts | 13 + src/experimental/tasks/client.ts | 264 ++++++++++++++++++ src/experimental/tasks/helpers.ts | 88 ++++++ src/experimental/tasks/index.ts | 34 +++ .../tasks/interfaces.ts} | 102 ++++++- src/experimental/tasks/mcp-server.ts | 142 ++++++++++ src/experimental/tasks/server.ts | 131 +++++++++ .../tasks/stores/in-memory.test.ts} | 6 +- .../tasks/stores/in-memory.ts} | 15 +- .../tasks}/task-listing.test.ts | 28 +- .../tasks}/task.test.ts | 4 +- src/experimental/tasks/types.ts | 43 +++ src/integration-tests/taskLifecycle.test.ts | 24 +- src/server/index.test.ts | 44 +-- src/server/index.ts | 104 ++----- src/server/mcp.test.ts | 18 +- src/server/mcp.ts | 126 ++------- src/shared/protocol.test.ts | 81 ++++-- src/shared/protocol.ts | 22 +- 26 files changed, 1083 insertions(+), 515 deletions(-) create mode 100644 src/experimental/index.ts create mode 100644 src/experimental/tasks/client.ts create mode 100644 src/experimental/tasks/helpers.ts create mode 100644 src/experimental/tasks/index.ts rename src/{shared/task.ts => experimental/tasks/interfaces.ts} (69%) create mode 100644 src/experimental/tasks/mcp-server.ts create mode 100644 src/experimental/tasks/server.ts rename src/{examples/shared/inMemoryTaskStore.test.ts => experimental/tasks/stores/in-memory.test.ts} (99%) rename src/{examples/shared/inMemoryTaskStore.ts => experimental/tasks/stores/in-memory.ts} (96%) rename src/{shared => experimental/tasks}/task-listing.test.ts (83%) rename src/{shared => experimental/tasks}/task.test.ts (97%) create mode 100644 src/experimental/tasks/types.ts diff --git a/README.md b/README.md index d2da51d7f..6757a855a 100644 --- a/README.md +++ b/README.md @@ -1385,6 +1385,8 @@ const client = new Client( ### Task-Based Execution +> **⚠️ Experimental API**: Task-based execution is an experimental feature and may change without notice. Access these APIs via the `.experimental.tasks` namespace. + Task-based execution enables "call-now, fetch-later" patterns for long-running operations. This is useful for tools that take significant time to complete, where clients may want to disconnect and check on progress or retrieve results later. Common use cases include: @@ -1400,7 +1402,7 @@ To enable task-based execution, configure your server with a `TaskStore` impleme ```typescript import { Server } from '@modelcontextprotocol/sdk/server/index.js'; -import { TaskStore } from '@modelcontextprotocol/sdk/shared/task.js'; +import { TaskStore } from '@modelcontextprotocol/sdk/experimental'; import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; // Implement TaskStore backed by your database (e.g., PostgreSQL, Redis, etc.) @@ -1458,8 +1460,8 @@ const server = new Server( } ); -// Register a tool that supports tasks -server.registerToolTask( +// Register a tool that supports tasks using the experimental API +server.experimental.tasks.registerToolTask( 'my-echo-tool', { title: 'My Echo Tool', @@ -1508,7 +1510,7 @@ server.registerToolTask( #### Client-Side: Using Task-Based Execution -Clients use `callToolStream()` to initiate task-augmented tool calls. The returned `AsyncGenerator` abstracts automatic polling and status updates: +Clients use `experimental.tasks.callToolStream()` to initiate task-augmented tool calls. The returned `AsyncGenerator` abstracts automatic polling and status updates: ```typescript import { Client } from '@modelcontextprotocol/sdk/client/index.js'; @@ -1521,8 +1523,8 @@ const client = new Client({ // ... connect to server ... -// Call the tool with task metadata using streaming API -const stream = client.callToolStream( +// Call the tool with task metadata using the experimental streaming API +const stream = client.experimental.tasks.callToolStream( { name: 'my-echo-tool', arguments: { message: 'Hello, world!' } @@ -1566,7 +1568,7 @@ if (taskStatus.status === 'completed') { } ``` -The `callToolStream()` method also works with non-task tools, making it a drop-in replacement for `callTool()` in applications that support it. When used to invoke a tool that doesn't support tasks, the `taskCreated` and `taskStatus` events will not be emitted. +The `experimental.tasks.callToolStream()` method also works with non-task tools, making it a drop-in replacement for `callTool()` in applications that support it. When used to invoke a tool that doesn't support tasks, the `taskCreated` and `taskStatus` events will not be emitted. #### Task Status Lifecycle diff --git a/package.json b/package.json index 9aa77ff2e..7eb668846 100644 --- a/package.json +++ b/package.json @@ -43,6 +43,14 @@ "import": "./dist/esm/validation/cfworker-provider.js", "require": "./dist/cjs/validation/cfworker-provider.js" }, + "./experimental": { + "import": "./dist/esm/experimental/index.js", + "require": "./dist/cjs/experimental/index.js" + }, + "./experimental/tasks": { + "import": "./dist/esm/experimental/tasks/index.js", + "require": "./dist/cjs/experimental/tasks/index.js" + }, "./*": { "import": "./dist/esm/*", "require": "./dist/cjs/*" diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 38af2a841..e161403fc 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -26,7 +26,7 @@ import { Transport } from '../shared/transport.js'; import { Server } from '../server/index.js'; import { McpServer } from '../server/mcp.js'; import { InMemoryTransport } from '../inMemory.js'; -import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../experimental/tasks/stores/in-memory.js'; import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; @@ -1818,7 +1818,7 @@ describe('Task-based execution', () => { } ); - server.registerToolTask( + server.experimental.tasks.registerToolTask( 'test-tool', { description: 'A test tool', @@ -1868,7 +1868,7 @@ describe('Task-based execution', () => { }); // Verify task was created successfully by listing tasks - const taskList = await client.listTasks(); + const taskList = await client.experimental.tasks.listTasks(); expect(taskList.tasks.length).toBeGreaterThan(0); const task = taskList.tasks[0]; expect(task.status).toBe('completed'); @@ -1894,7 +1894,7 @@ describe('Task-based execution', () => { } ); - server.registerToolTask( + server.experimental.tasks.registerToolTask( 'test-tool', { description: 'A test tool', @@ -1942,7 +1942,7 @@ describe('Task-based execution', () => { }); // Query task status by listing tasks and getting the first one - const taskList = await client.listTasks(); + const taskList = await client.experimental.tasks.listTasks(); expect(taskList.tasks.length).toBeGreaterThan(0); const task = taskList.tasks[0]; expect(task).toBeDefined(); @@ -1971,7 +1971,7 @@ describe('Task-based execution', () => { } ); - server.registerToolTask( + server.experimental.tasks.registerToolTask( 'test-tool', { description: 'A test tool', @@ -2015,7 +2015,7 @@ describe('Task-based execution', () => { // Create a task using callToolStream to capture the task ID let taskId: string | undefined; - const stream = client.callToolStream({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } }); @@ -2028,7 +2028,7 @@ describe('Task-based execution', () => { expect(taskId).toBeDefined(); // Query task result using the captured task ID - const result = await client.getTaskResult({ taskId: taskId! }, CallToolResultSchema); + const result = await client.experimental.tasks.getTaskResult(taskId!, CallToolResultSchema); expect(result.content).toEqual([{ type: 'text', text: 'Result data!' }]); }); @@ -2052,7 +2052,7 @@ describe('Task-based execution', () => { } ); - server.registerToolTask( + server.experimental.tasks.registerToolTask( 'test-tool', { description: 'A test tool', @@ -2103,7 +2103,7 @@ describe('Task-based execution', () => { }); // Get the task ID from the task list - const taskList = await client.listTasks(); + const taskList = await client.experimental.tasks.listTasks(); const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); if (newTask) { createdTaskIds.push(newTask.taskId); @@ -2111,7 +2111,7 @@ describe('Task-based execution', () => { } // Query task list - const taskList = await client.listTasks(); + const taskList = await client.experimental.tasks.listTasks(); expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); for (const taskId of createdTaskIds) { expect(taskList.tasks).toContainEqual( @@ -2224,7 +2224,7 @@ describe('Task-based execution', () => { const taskId = createTaskResult.task.taskId; // Verify task was created - const task = await server.getTask({ taskId }); + const task = await server.experimental.tasks.getTask(taskId); expect(task.status).toBe('completed'); }); @@ -2314,7 +2314,7 @@ describe('Task-based execution', () => { const taskId = createTaskResult.task.taskId; // Query task status - const task = await server.getTask({ taskId }); + const task = await server.experimental.tasks.getTask(taskId); expect(task).toBeDefined(); expect(task.taskId).toBe(taskId); expect(task.status).toBe('completed'); @@ -2406,7 +2406,7 @@ describe('Task-based execution', () => { const taskId = createTaskResult.task.taskId; // Query task result using getTaskResult - const taskResult = await server.getTaskResult({ taskId }, ElicitResultSchema); + const taskResult = await server.experimental.tasks.getTaskResult(taskId, ElicitResultSchema); expect(taskResult.action).toBe('accept'); expect(taskResult.content).toEqual({ username: 'result-user' }); }); @@ -2500,7 +2500,7 @@ describe('Task-based execution', () => { } // Query task list - const taskList = await server.listTasks(); + const taskList = await server.experimental.tasks.listTasks(); expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); for (const taskId of createdTaskIds) { expect(taskList.tasks).toContainEqual( @@ -2535,7 +2535,7 @@ describe('Task-based execution', () => { } ); - server.registerToolTask( + server.experimental.tasks.registerToolTask( 'test-tool', { description: 'A test tool', @@ -2601,7 +2601,7 @@ describe('Task-based execution', () => { }); // Get the task ID from the task list - const taskList = await client.listTasks(); + const taskList = await client.experimental.tasks.listTasks(); const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); if (newTask) { createdTaskIds.push(newTask.taskId); @@ -2609,13 +2609,13 @@ describe('Task-based execution', () => { } // List all tasks without cursor - const firstPage = await client.listTasks(); + const firstPage = await client.experimental.tasks.listTasks(); expect(firstPage.tasks.length).toBeGreaterThan(0); expect(firstPage.tasks.map(t => t.taskId)).toEqual(expect.arrayContaining(createdTaskIds)); // If there's a cursor, test pagination if (firstPage.nextCursor) { - const secondPage = await client.listTasks({ cursor: firstPage.nextCursor }); + const secondPage = await client.experimental.tasks.listTasks(firstPage.nextCursor); expect(secondPage.tasks).toBeDefined(); } @@ -2680,7 +2680,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Try to get a task that doesn't exist - await expect(client.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + await expect(client.experimental.tasks.getTask('non-existent-task')).rejects.toThrow(); }); test('should throw error when querying result of non-existent task from server', async () => { @@ -2727,7 +2727,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Try to get result of a task that doesn't exist - await expect(client.getTaskResult({ taskId: 'non-existent-task' }, CallToolResultSchema)).rejects.toThrow(); + await expect(client.experimental.tasks.getTaskResult('non-existent-task', CallToolResultSchema)).rejects.toThrow(); }); test('should throw error when server queries non-existent task from client', async () => { @@ -2779,7 +2779,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Try to query a task that doesn't exist on client - await expect(server.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + await expect(server.experimental.tasks.getTask('non-existent-task')).rejects.toThrow(); }); }); }); @@ -2805,7 +2805,7 @@ test('should respect server task capabilities', async () => { } ); - server.registerToolTask( + server.experimental.tasks.registerToolTask( 'test-tool', { description: 'A test tool', @@ -2871,7 +2871,7 @@ test('should respect server task capabilities', async () => { task: { ttl: 60000 } }) ).resolves.not.toThrow(); - await expect(client.listTasks()).resolves.not.toThrow(); + await expect(client.experimental.tasks.listTasks()).resolves.not.toThrow(); // tools/list doesn't support task creation, but it shouldn't throw - it should just ignore the task metadata await expect( @@ -2928,7 +2928,7 @@ test('should expose requestStream() method for streaming responses', async () => expect(regularResult.content).toEqual([{ type: 'text', text: 'Tool result' }]); // Test requestStream with non-task request (should yield only result) - const stream = client.requestStream( + const stream = client.experimental.tasks.requestStream( { method: 'tools/call', params: { name: 'test-tool', arguments: {} } @@ -2989,7 +2989,7 @@ test('should expose callToolStream() method for streaming tool calls', async () await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Test callToolStream - const stream = client.callToolStream({ name: 'test-tool', arguments: {} }); + const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); const messages = []; for await (const message of stream) { @@ -3070,7 +3070,7 @@ test('should validate structured output in callToolStream()', async () => { await client.listTools(); // Test callToolStream with valid structured output - const stream = client.callToolStream({ name: 'structured-tool', arguments: {} }); + const stream = client.experimental.tasks.callToolStream({ name: 'structured-tool', arguments: {} }); const messages = []; for await (const message of stream) { diff --git a/src/client/index.ts b/src/client/index.ts index 367f6c998..d6dfe82d9 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,6 +1,6 @@ import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; -import { ResponseMessage, takeResult } from '../shared/responseMessage.js'; +import { takeResult } from '../shared/responseMessage.js'; import { type CallToolRequest, @@ -57,6 +57,8 @@ import { type ZodV4Internal } from '../server/zod-compat.js'; import type { RequestHandlerExtra } from '../shared/protocol.js'; +import { ExperimentalClientTasks } from '../experimental/tasks/client.js'; +import { assertToolsCallTaskCapability, assertClientRequestTaskCapability } from '../experimental/tasks/helpers.js'; /** * Elicitation default application helper. Applies defaults to the data based on the schema. @@ -201,6 +203,7 @@ export class Client< private _jsonSchemaValidator: jsonSchemaValidator; private _cachedToolOutputValidators: Map> = new Map(); private _cachedKnownTaskTools: Set = new Set(); + private _experimental?: { tasks: ExperimentalClientTasks }; /** * Initializes this client with the given name and version information. @@ -214,6 +217,22 @@ export class Client< this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new AjvJsonSchemaValidator(); } + /** + * Access experimental features. + * + * WARNING: These APIs are experimental and may change without notice. + * + * @experimental + */ + get experimental(): { tasks: ExperimentalClientTasks } { + if (!this._experimental) { + this._experimental = { + tasks: new ExperimentalClientTasks(this) + }; + } + return this._experimental; + } + /** * Registers new capabilities. This can only be called before connecting to a transport. * @@ -568,23 +587,7 @@ export class Client< } protected assertTaskCapability(method: string): void { - if (!this._serverCapabilities?.tasks?.requests) { - throw new Error(`Server does not support task creation (required for ${method})`); - } - - const requests = this._serverCapabilities.tasks.requests; - - switch (method) { - case 'tools/call': - if (!requests.tools?.call) { - throw new Error(`Server does not support task creation for tools/call (required for ${method})`); - } - break; - - default: - // Method doesn't support tasks, which is fine - no error - break; - } + assertToolsCallTaskCapability(this._serverCapabilities?.tasks?.requests, method, 'Server'); } protected assertTaskHandlerCapability(method: string): void { @@ -594,29 +597,7 @@ export class Client< return; } - if (!this._capabilities.tasks?.requests) { - throw new Error(`Client does not support task creation (required for ${method})`); - } - - const requests = this._capabilities.tasks.requests; - - switch (method) { - case 'sampling/createMessage': - if (!requests.sampling?.createMessage) { - throw new Error(`Client does not support task creation for sampling/createMessage (required for ${method})`); - } - break; - - case 'elicitation/create': - if (!requests.elicitation?.create) { - throw new Error(`Client does not support task creation for elicitation/create (required for ${method})`); - } - break; - - default: - // Method doesn't support tasks, which is fine - no error - break; - } + assertClientRequestTaskCapability(this._capabilities.tasks?.requests, method, 'Client'); } async ping(options?: RequestOptions) { @@ -662,125 +643,15 @@ export class Client< /** * Calls a tool and waits for the result. Automatically validates structured output if the tool has an outputSchema. * - * For task-based execution with streaming behavior, use callToolStream() instead. + * For task-based execution with streaming behavior, use client.experimental.tasks.callToolStream() instead. */ async callTool( params: CallToolRequest['params'], resultSchema: T = CallToolResultSchema as T, options?: RequestOptions ): Promise> { - return await takeResult(this.callToolStream(params, resultSchema, options)); - } - - /** - * Calls a tool and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a 'result' or 'error' message. - * - * This method provides streaming access to tool execution, allowing you to - * observe intermediate task status updates for long-running tool calls. - * Automatically validates structured output if the tool has an outputSchema. - * - * For simple tool calls without streaming, use callTool() instead. - * - * @example - * ```typescript - * const stream = client.callToolStream({ name: 'myTool', arguments: {} }); - * for await (const message of stream) { - * switch (message.type) { - * case 'taskCreated': - * console.log('Tool execution started:', message.task.taskId); - * break; - * case 'taskStatus': - * console.log('Tool status:', message.task.status); - * break; - * case 'result': - * console.log('Tool result:', message.result); - * // Structured output is automatically validated - * break; - * case 'error': - * console.error('Tool error:', message.error); - * break; - * } - * } - * ``` - * - * @param params - Tool call parameters (name and arguments) - * @param resultSchema - Zod schema for validating the result (defaults to CallToolResultSchema) - * @param options - Optional request options (timeout, signal, task creation params, etc.) - * @returns AsyncGenerator that yields ResponseMessage objects - */ - async *callToolStream( - params: CallToolRequest['params'], - resultSchema: T = CallToolResultSchema as T, - options?: RequestOptions - ): AsyncGenerator>, void, void> { - // Add task creation parameters if server supports it and not explicitly provided - const optionsWithTask = { - ...options, - // We check if the tool is known to be a task during auto-configuration, but assume - // the caller knows what they're doing if they pass this explicitly - task: options?.task ?? (this.isToolTask(params.name) ? {} : undefined) - }; - - const stream = this.requestStream({ method: 'tools/call', params }, resultSchema, optionsWithTask); - - // Get the validator for this tool (if it has an output schema) - const validator = this.getToolOutputValidator(params.name); - - // Iterate through the stream and validate the final result if needed - for await (const message of stream) { - // If this is a result message and the tool has an output schema, validate it - if (message.type === 'result' && validator) { - const result = message.result; - - // If tool has outputSchema, it MUST return structuredContent (unless it's an error) - if (!result.structuredContent && !result.isError) { - yield { - type: 'error', - error: new McpError( - ErrorCode.InvalidRequest, - `Tool ${params.name} has an output schema but did not return structured content` - ) - }; - return; - } - - // Only validate structured content if present (not when there's an error) - if (result.structuredContent) { - try { - // Validate the structured content against the schema - const validationResult = validator(result.structuredContent); - - if (!validationResult.valid) { - yield { - type: 'error', - error: new McpError( - ErrorCode.InvalidParams, - `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` - ) - }; - return; - } - } catch (error) { - if (error instanceof McpError) { - yield { type: 'error', error }; - return; - } - yield { - type: 'error', - error: new McpError( - ErrorCode.InvalidParams, - `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` - ) - }; - return; - } - } - } - - // Yield the message (either validated result or any other message type) - yield message; - } + // Use experimental.tasks.callToolStream for implementation (temporary dependency) + return await takeResult(this.experimental.tasks.callToolStream(params, resultSchema, options)); } private isToolTask(toolName: string): boolean { @@ -833,45 +704,4 @@ export class Client< async sendRootsListChanged() { return this.notification({ method: 'notifications/roots/list_changed' }); } - - /** - * Sends a request and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a 'result' or 'error' message. - * - * This method provides streaming access to request processing, allowing you to - * observe intermediate task status updates for task-augmented requests. - * - * @example - * ```typescript - * const stream = client.requestStream(request, resultSchema, options); - * for await (const message of stream) { - * switch (message.type) { - * case 'taskCreated': - * console.log('Task created:', message.task.taskId); - * break; - * case 'taskStatus': - * console.log('Task status:', message.task.status); - * break; - * case 'result': - * console.log('Final result:', message.result); - * break; - * case 'error': - * console.error('Error:', message.error); - * break; - * } - * } - * ``` - * - * @param request - The request to send - * @param resultSchema - Zod schema for validating the result - * @param options - Optional request options (timeout, signal, task creation params, etc.) - * @returns AsyncGenerator that yields ResponseMessage objects - */ - requestStream( - request: ClientRequest | RequestT, - resultSchema: T, - options?: RequestOptions - ): AsyncGenerator>, void, void> { - return super.requestStream(request, resultSchema, options); - } } diff --git a/src/examples/client/simpleOAuthClient.ts b/src/examples/client/simpleOAuthClient.ts index 4dc724d25..8071e61ac 100644 --- a/src/examples/client/simpleOAuthClient.ts +++ b/src/examples/client/simpleOAuthClient.ts @@ -355,9 +355,10 @@ class InteractiveOAuthClient { } try { + // Using the experimental tasks API - WARNING: may change without notice console.log(`\n🔧 Streaming tool '${toolName}'...`); - const stream = this.client.callToolStream( + const stream = this.client.experimental.tasks.callToolStream( { name: toolName, arguments: toolArgs diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 4dbd109d6..21ab4f556 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -814,11 +814,12 @@ async function callToolTask(name: string, args: Record): Promis console.log('Arguments:', args); // Use task-based execution - call now, fetch later + // Using the experimental tasks API - WARNING: may change without notice console.log('This will return immediately while processing continues in the background...'); try { // Call the tool with task metadata using streaming API - const stream = client.callToolStream( + const stream = client.experimental.tasks.callToolStream( { name, arguments: args diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 8a68ab2de..9d3afda97 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -15,7 +15,7 @@ import { ResourceLink } from '../../types.js'; import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; -import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../../experimental/tasks/stores/in-memory.js'; import { setupAuthServer } from './demoInMemoryOAuthProvider.js'; import { OAuthMetadata } from '../../shared/auth.js'; import { checkResourceAllowed } from '../../shared/auth-utils.js'; @@ -456,7 +456,8 @@ const getServer = () => { ); // Register a long-running tool that demonstrates task execution - server.registerToolTask( + // Using the experimental tasks API - WARNING: may change without notice + server.experimental.tasks.registerToolTask( 'delay', { title: 'Delay', diff --git a/src/experimental/index.ts b/src/experimental/index.ts new file mode 100644 index 000000000..55dd44ed0 --- /dev/null +++ b/src/experimental/index.ts @@ -0,0 +1,13 @@ +/** + * Experimental MCP SDK features. + * WARNING: These APIs are experimental and may change without notice. + * + * Import experimental features from this module: + * ```typescript + * import { TaskStore, InMemoryTaskStore } from '@modelcontextprotocol/sdk/experimental'; + * ``` + * + * @experimental + */ + +export * from './tasks/index.js'; diff --git a/src/experimental/tasks/client.ts b/src/experimental/tasks/client.ts new file mode 100644 index 000000000..f62941dc8 --- /dev/null +++ b/src/experimental/tasks/client.ts @@ -0,0 +1,264 @@ +/** + * Experimental client task features for MCP SDK. + * WARNING: These APIs are experimental and may change without notice. + * + * @experimental + */ + +import type { Client } from '../../client/index.js'; +import type { RequestOptions } from '../../shared/protocol.js'; +import type { ResponseMessage } from '../../shared/responseMessage.js'; +import type { AnyObjectSchema, SchemaOutput } from '../../server/zod-compat.js'; +import type { CallToolRequest, ClientRequest, Notification, Request, Result } from '../../types.js'; +import { CallToolResultSchema, type CompatibilityCallToolResultSchema, McpError, ErrorCode } from '../../types.js'; + +import type { GetTaskResult, ListTasksResult, CancelTaskResult } from './types.js'; + +/** + * Internal interface for accessing Client's private methods. + * @internal + */ +interface ClientInternal { + requestStream( + request: ClientRequest | RequestT, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void>; + isToolTask(toolName: string): boolean; + getToolOutputValidator(toolName: string): ((data: unknown) => { valid: boolean; errorMessage?: string }) | undefined; +} + +/** + * Experimental task features for MCP clients. + * + * Access via `client.experimental.tasks`: + * ```typescript + * const stream = client.experimental.tasks.callToolStream({ name: 'tool', arguments: {} }); + * const task = await client.experimental.tasks.getTask(taskId); + * ``` + * + * @experimental + */ +export class ExperimentalClientTasks< + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + ResultT extends Result = Result +> { + constructor(private readonly _client: Client) {} + + /** + * Calls a tool and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to tool execution, allowing you to + * observe intermediate task status updates for long-running tool calls. + * Automatically validates structured output if the tool has an outputSchema. + * + * @example + * ```typescript + * const stream = client.experimental.tasks.callToolStream({ name: 'myTool', arguments: {} }); + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Tool execution started:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Tool status:', message.task.status); + * break; + * case 'result': + * console.log('Tool result:', message.result); + * break; + * case 'error': + * console.error('Tool error:', message.error); + * break; + * } + * } + * ``` + * + * @param params - Tool call parameters (name and arguments) + * @param resultSchema - Zod schema for validating the result (defaults to CallToolResultSchema) + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + * + * @experimental + */ + async *callToolStream( + params: CallToolRequest['params'], + resultSchema: T = CallToolResultSchema as T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + // Access Client's internal methods + const clientInternal = this._client as unknown as ClientInternal; + + // Add task creation parameters if server supports it and not explicitly provided + const optionsWithTask = { + ...options, + // We check if the tool is known to be a task during auto-configuration, but assume + // the caller knows what they're doing if they pass this explicitly + task: options?.task ?? (clientInternal.isToolTask(params.name) ? {} : undefined) + }; + + const stream = clientInternal.requestStream({ method: 'tools/call', params }, resultSchema, optionsWithTask); + + // Get the validator for this tool (if it has an output schema) + const validator = clientInternal.getToolOutputValidator(params.name); + + // Iterate through the stream and validate the final result if needed + for await (const message of stream) { + // If this is a result message and the tool has an output schema, validate it + if (message.type === 'result' && validator) { + const result = message.result; + + // If tool has outputSchema, it MUST return structuredContent (unless it's an error) + if (!result.structuredContent && !result.isError) { + yield { + type: 'error', + error: new McpError( + ErrorCode.InvalidRequest, + `Tool ${params.name} has an output schema but did not return structured content` + ) + }; + return; + } + + // Only validate structured content if present (not when there's an error) + if (result.structuredContent) { + try { + // Validate the structured content against the schema + const validationResult = validator(result.structuredContent); + + if (!validationResult.valid) { + yield { + type: 'error', + error: new McpError( + ErrorCode.InvalidParams, + `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` + ) + }; + return; + } + } catch (error) { + if (error instanceof McpError) { + yield { type: 'error', error }; + return; + } + yield { + type: 'error', + error: new McpError( + ErrorCode.InvalidParams, + `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` + ) + }; + return; + } + } + } + + // Yield the message (either validated result or any other message type) + yield message; + } + } + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @param options - Optional request options + * @returns The task status + * + * @experimental + */ + async getTask(taskId: string, options?: RequestOptions): Promise { + // Delegate to the client's underlying Protocol method + type ClientWithGetTask = { getTask(params: { taskId: string }, options?: RequestOptions): Promise }; + return (this._client as unknown as ClientWithGetTask).getTask({ taskId }, options); + } + + /** + * Retrieves the result of a completed task. + * + * @param taskId - The task identifier + * @param resultSchema - Zod schema for validating the result + * @param options - Optional request options + * @returns The task result + * + * @experimental + */ + async getTaskResult(taskId: string, resultSchema?: T, options?: RequestOptions): Promise> { + // Delegate to the client's underlying Protocol method + return ( + this._client as unknown as { + getTaskResult: ( + params: { taskId: string }, + resultSchema?: U, + options?: RequestOptions + ) => Promise>; + } + ).getTaskResult({ taskId }, resultSchema, options); + } + + /** + * Lists tasks with optional pagination. + * + * @param cursor - Optional pagination cursor + * @param options - Optional request options + * @returns List of tasks with optional next cursor + * + * @experimental + */ + async listTasks(cursor?: string, options?: RequestOptions): Promise { + // Delegate to the client's underlying Protocol method + return ( + this._client as unknown as { + listTasks: (params?: { cursor?: string }, options?: RequestOptions) => Promise; + } + ).listTasks(cursor ? { cursor } : undefined, options); + } + + /** + * Cancels a running task. + * + * @param taskId - The task identifier + * @param options - Optional request options + * + * @experimental + */ + async cancelTask(taskId: string, options?: RequestOptions): Promise { + // Delegate to the client's underlying Protocol method + return ( + this._client as unknown as { + cancelTask: (params: { taskId: string }, options?: RequestOptions) => Promise; + } + ).cancelTask({ taskId }, options); + } + + /** + * Sends a request and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to request processing, allowing you to + * observe intermediate task status updates for task-augmented requests. + * + * @param request - The request to send + * @param resultSchema - Zod schema for validating the result + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + * + * @experimental + */ + requestStream( + request: ClientRequest | RequestT, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + // Delegate to the client's underlying Protocol method + type ClientWithRequestStream = { + requestStream( + request: ClientRequest | RequestT, + resultSchema: U, + options?: RequestOptions + ): AsyncGenerator>, void, void>; + }; + return (this._client as unknown as ClientWithRequestStream).requestStream(request, resultSchema, options); + } +} diff --git a/src/experimental/tasks/helpers.ts b/src/experimental/tasks/helpers.ts new file mode 100644 index 000000000..34b15188f --- /dev/null +++ b/src/experimental/tasks/helpers.ts @@ -0,0 +1,88 @@ +/** + * Experimental task capability assertion helpers. + * WARNING: These APIs are experimental and may change without notice. + * + * @experimental + */ + +/** + * Type representing the task requests capability structure. + * This is derived from ClientTasksCapability.requests and ServerTasksCapability.requests. + */ +interface TaskRequestsCapability { + tools?: { call?: object }; + sampling?: { createMessage?: object }; + elicitation?: { create?: object }; +} + +/** + * Asserts that task creation is supported for tools/call. + * Used by Client.assertTaskCapability and Server.assertTaskHandlerCapability. + * + * @param requests - The task requests capability object + * @param method - The method being checked + * @param entityName - 'Server' or 'Client' for error messages + * @throws Error if the capability is not supported + * + * @experimental + */ +export function assertToolsCallTaskCapability( + requests: TaskRequestsCapability | undefined, + method: string, + entityName: 'Server' | 'Client' +): void { + if (!requests) { + throw new Error(`${entityName} does not support task creation (required for ${method})`); + } + + switch (method) { + case 'tools/call': + if (!requests.tools?.call) { + throw new Error(`${entityName} does not support task creation for tools/call (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } +} + +/** + * Asserts that task creation is supported for sampling/createMessage or elicitation/create. + * Used by Server.assertTaskCapability and Client.assertTaskHandlerCapability. + * + * @param requests - The task requests capability object + * @param method - The method being checked + * @param entityName - 'Server' or 'Client' for error messages + * @throws Error if the capability is not supported + * + * @experimental + */ +export function assertClientRequestTaskCapability( + requests: TaskRequestsCapability | undefined, + method: string, + entityName: 'Server' | 'Client' +): void { + if (!requests) { + throw new Error(`${entityName} does not support task creation (required for ${method})`); + } + + switch (method) { + case 'sampling/createMessage': + if (!requests.sampling?.createMessage) { + throw new Error(`${entityName} does not support task creation for sampling/createMessage (required for ${method})`); + } + break; + + case 'elicitation/create': + if (!requests.elicitation?.create) { + throw new Error(`${entityName} does not support task creation for elicitation/create (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } +} diff --git a/src/experimental/tasks/index.ts b/src/experimental/tasks/index.ts new file mode 100644 index 000000000..398d34393 --- /dev/null +++ b/src/experimental/tasks/index.ts @@ -0,0 +1,34 @@ +/** + * Experimental task features for MCP SDK. + * WARNING: These APIs are experimental and may change without notice. + * + * @experimental + */ + +// Re-export spec types for convenience +export * from './types.js'; + +// SDK implementation interfaces +export * from './interfaces.js'; + +// Assertion helpers +export * from './helpers.js'; + +// Wrapper classes +export * from './client.js'; +export * from './server.js'; +export * from './mcp-server.js'; + +// Store implementations +export * from './stores/in-memory.js'; + +// Re-export response message types for task streaming +export type { + ResponseMessage, + TaskStatusMessage, + TaskCreatedMessage, + ResultMessage, + ErrorMessage, + BaseResponseMessage +} from '../../shared/responseMessage.js'; +export { takeResult, toArrayAsync } from '../../shared/responseMessage.js'; diff --git a/src/shared/task.ts b/src/experimental/tasks/interfaces.ts similarity index 69% rename from src/shared/task.ts rename to src/experimental/tasks/interfaces.ts index ae4517f6f..4800e65dc 100644 --- a/src/shared/task.ts +++ b/src/experimental/tasks/interfaces.ts @@ -1,4 +1,98 @@ -import { Task, Request, RequestId, Result, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, JSONRPCError } from '../types.js'; +/** + * Experimental task interfaces for MCP SDK. + * WARNING: These APIs are experimental and may change without notice. + */ + +import { + Task, + Request, + RequestId, + Result, + JSONRPCRequest, + JSONRPCNotification, + JSONRPCResponse, + JSONRPCError, + ServerRequest, + ServerNotification, + CallToolResult, + GetTaskResult, + ToolExecution +} from '../../types.js'; +import { CreateTaskResult } from './types.js'; +import type { RequestHandlerExtra, RequestTaskStore } from '../../shared/protocol.js'; +import type { ZodRawShapeCompat, AnySchema, ShapeOutput } from '../../server/zod-compat.js'; + +// ============================================================================ +// Task Handler Types (for registerToolTask) +// ============================================================================ + +/** + * Extended handler extra with task store for task creation. + * @experimental + */ +export interface CreateTaskRequestHandlerExtra extends RequestHandlerExtra { + taskStore: RequestTaskStore; +} + +/** + * Extended handler extra with task ID and store for task operations. + * @experimental + */ +export interface TaskRequestHandlerExtra extends RequestHandlerExtra { + taskId: string; + taskStore: RequestTaskStore; +} + +/** + * Base callback type for tool handlers. + * @experimental + */ +export type BaseToolCallback< + SendResultT extends Result, + ExtraT extends RequestHandlerExtra, + Args extends undefined | ZodRawShapeCompat | AnySchema = undefined +> = Args extends ZodRawShapeCompat + ? (args: ShapeOutput, extra: ExtraT) => SendResultT | Promise + : Args extends AnySchema + ? (args: unknown, extra: ExtraT) => SendResultT | Promise + : (extra: ExtraT) => SendResultT | Promise; + +/** + * Handler for creating a task. + * @experimental + */ +export type CreateTaskRequestHandler< + SendResultT extends Result, + Args extends undefined | ZodRawShapeCompat | AnySchema = undefined +> = BaseToolCallback; + +/** + * Handler for task operations (get, getResult). + * @experimental + */ +export type TaskRequestHandler< + SendResultT extends Result, + Args extends undefined | ZodRawShapeCompat | AnySchema = undefined +> = BaseToolCallback; + +/** + * Interface for task-based tool handlers. + * @experimental + */ +export interface ToolTaskHandler { + createTask: CreateTaskRequestHandler; + getTask: TaskRequestHandler; + getTaskResult: TaskRequestHandler; +} + +/** + * Task-specific execution configuration. + * taskSupport cannot be 'forbidden' for task-based tools. + * @experimental + */ +export type TaskToolExecution = Omit & { + taskSupport: TaskSupport extends 'forbidden' | undefined ? never : TaskSupport; +}; /** * Represents a message queued for side-channel delivery via tasks/result. @@ -51,6 +145,8 @@ export interface QueuedError extends BaseQueuedMessage { * * All methods are async to support external storage implementations. * All data in QueuedMessage must be JSON-serializable. + * + * @experimental */ export interface TaskMessageQueue { /** @@ -84,6 +180,7 @@ export interface TaskMessageQueue { /** * Task creation options. + * @experimental */ export interface CreateTaskOptions { /** @@ -108,6 +205,8 @@ export interface CreateTaskOptions { * * Similar to Transport, this allows pluggable task storage implementations * (in-memory, database, distributed cache, etc.). + * + * @experimental */ export interface TaskStore { /** @@ -183,6 +282,7 @@ export interface TaskStore { * * @param status - The task status to check * @returns True if the status is terminal (completed, failed, or cancelled) + * @experimental */ export function isTerminal(status: Task['status']): boolean { return status === 'completed' || status === 'failed' || status === 'cancelled'; diff --git a/src/experimental/tasks/mcp-server.ts b/src/experimental/tasks/mcp-server.ts new file mode 100644 index 000000000..506f3d72b --- /dev/null +++ b/src/experimental/tasks/mcp-server.ts @@ -0,0 +1,142 @@ +/** + * Experimental McpServer task features for MCP SDK. + * WARNING: These APIs are experimental and may change without notice. + * + * @experimental + */ + +import type { McpServer, RegisteredTool, AnyToolHandler } from '../../server/mcp.js'; +import type { ZodRawShapeCompat, AnySchema } from '../../server/zod-compat.js'; +import type { ToolAnnotations, ToolExecution } from '../../types.js'; +import type { ToolTaskHandler, TaskToolExecution } from './interfaces.js'; + +/** + * Internal interface for accessing McpServer's private _createRegisteredTool method. + * @internal + */ +interface McpServerInternal { + _createRegisteredTool( + name: string, + title: string | undefined, + description: string | undefined, + inputSchema: ZodRawShapeCompat | AnySchema | undefined, + outputSchema: ZodRawShapeCompat | AnySchema | undefined, + annotations: ToolAnnotations | undefined, + execution: ToolExecution | undefined, + _meta: Record | undefined, + handler: AnyToolHandler + ): RegisteredTool; +} + +/** + * Experimental task features for McpServer. + * + * Access via `server.experimental.tasks`: + * ```typescript + * server.experimental.tasks.registerToolTask('long-running', config, handler); + * ``` + * + * @experimental + */ +export class ExperimentalMcpServerTasks { + constructor(private readonly _mcpServer: McpServer) {} + + /** + * Registers a task-based tool with a config object and handler. + * + * Task-based tools support long-running operations that can be polled for status + * and results. The handler must implement `createTask`, `getTask`, and `getTaskResult` + * methods. + * + * @example + * ```typescript + * server.experimental.tasks.registerToolTask('long-computation', { + * description: 'Performs a long computation', + * inputSchema: { input: z.string() }, + * execution: { taskSupport: 'required' } + * }, { + * createTask: async (args, extra) => { + * const task = await extra.taskStore.createTask({ ttl: 300000 }); + * startBackgroundWork(task.taskId, args); + * return { task }; + * }, + * getTask: async (args, extra) => { + * return extra.taskStore.getTask(extra.taskId); + * }, + * getTaskResult: async (args, extra) => { + * return extra.taskStore.getTaskResult(extra.taskId); + * } + * }); + * ``` + * + * @param name - The tool name + * @param config - Tool configuration (description, schemas, etc.) + * @param handler - Task handler with createTask, getTask, getTaskResult methods + * @returns RegisteredTool for managing the tool's lifecycle + * + * @experimental + */ + registerToolTask( + name: string, + config: { + title?: string; + description?: string; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + execution?: TaskToolExecution; + _meta?: Record; + }, + handler: ToolTaskHandler + ): RegisteredTool; + + registerToolTask( + name: string, + config: { + title?: string; + description?: string; + inputSchema: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + execution?: TaskToolExecution; + _meta?: Record; + }, + handler: ToolTaskHandler + ): RegisteredTool; + + registerToolTask< + InputArgs extends undefined | ZodRawShapeCompat | AnySchema, + OutputArgs extends undefined | ZodRawShapeCompat | AnySchema + >( + name: string, + config: { + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + execution?: TaskToolExecution; + _meta?: Record; + }, + handler: ToolTaskHandler + ): RegisteredTool { + // Validate that taskSupport is not 'forbidden' for task-based tools + const execution: ToolExecution = { taskSupport: 'required', ...config.execution }; + if (execution.taskSupport === 'forbidden') { + throw new Error(`Cannot register task-based tool '${name}' with taskSupport 'forbidden'. Use registerTool() instead.`); + } + + // Access McpServer's internal _createRegisteredTool method + const mcpServerInternal = this._mcpServer as unknown as McpServerInternal; + return mcpServerInternal._createRegisteredTool( + name, + config.title, + config.description, + config.inputSchema, + config.outputSchema, + config.annotations, + execution, + config._meta, + handler as AnyToolHandler + ); + } +} diff --git a/src/experimental/tasks/server.ts b/src/experimental/tasks/server.ts new file mode 100644 index 000000000..a4150a8d7 --- /dev/null +++ b/src/experimental/tasks/server.ts @@ -0,0 +1,131 @@ +/** + * Experimental server task features for MCP SDK. + * WARNING: These APIs are experimental and may change without notice. + * + * @experimental + */ + +import type { Server } from '../../server/index.js'; +import type { RequestOptions } from '../../shared/protocol.js'; +import type { ResponseMessage } from '../../shared/responseMessage.js'; +import type { AnySchema, SchemaOutput } from '../../server/zod-compat.js'; +import type { ServerRequest, Notification, Request, Result, GetTaskResult, ListTasksResult, CancelTaskResult } from '../../types.js'; + +/** + * Experimental task features for low-level MCP servers. + * + * Access via `server.experimental.tasks`: + * ```typescript + * const stream = server.experimental.tasks.requestStream(request, schema, options); + * ``` + * + * For high-level server usage with task-based tools, use `McpServer.experimental.tasks` instead. + * + * @experimental + */ +export class ExperimentalServerTasks< + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + ResultT extends Result = Result +> { + constructor(private readonly _server: Server) {} + + /** + * Sends a request and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to request processing, allowing you to + * observe intermediate task status updates for task-augmented requests. + * + * @param request - The request to send + * @param resultSchema - Zod schema for validating the result + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + * + * @experimental + */ + requestStream( + request: ServerRequest | RequestT, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + // Delegate to the server's underlying Protocol method + type ServerWithRequestStream = { + requestStream( + request: ServerRequest | RequestT, + resultSchema: U, + options?: RequestOptions + ): AsyncGenerator>, void, void>; + }; + return (this._server as unknown as ServerWithRequestStream).requestStream(request, resultSchema, options); + } + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @param options - Optional request options + * @returns The task status + * + * @experimental + */ + async getTask(taskId: string, options?: RequestOptions): Promise { + type ServerWithGetTask = { getTask(params: { taskId: string }, options?: RequestOptions): Promise }; + return (this._server as unknown as ServerWithGetTask).getTask({ taskId }, options); + } + + /** + * Retrieves the result of a completed task. + * + * @param taskId - The task identifier + * @param resultSchema - Zod schema for validating the result + * @param options - Optional request options + * @returns The task result + * + * @experimental + */ + async getTaskResult(taskId: string, resultSchema?: T, options?: RequestOptions): Promise> { + return ( + this._server as unknown as { + getTaskResult: ( + params: { taskId: string }, + resultSchema?: U, + options?: RequestOptions + ) => Promise>; + } + ).getTaskResult({ taskId }, resultSchema, options); + } + + /** + * Lists tasks with optional pagination. + * + * @param cursor - Optional pagination cursor + * @param options - Optional request options + * @returns List of tasks with optional next cursor + * + * @experimental + */ + async listTasks(cursor?: string, options?: RequestOptions): Promise { + return ( + this._server as unknown as { + listTasks: (params?: { cursor?: string }, options?: RequestOptions) => Promise; + } + ).listTasks(cursor ? { cursor } : undefined, options); + } + + /** + * Cancels a running task. + * + * @param taskId - The task identifier + * @param options - Optional request options + * + * @experimental + */ + async cancelTask(taskId: string, options?: RequestOptions): Promise { + return ( + this._server as unknown as { + cancelTask: (params: { taskId: string }, options?: RequestOptions) => Promise; + } + ).cancelTask({ taskId }, options); + } +} diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/experimental/tasks/stores/in-memory.test.ts similarity index 99% rename from src/examples/shared/inMemoryTaskStore.test.ts rename to src/experimental/tasks/stores/in-memory.test.ts index 658e4deb1..f589812ed 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/experimental/tasks/stores/in-memory.test.ts @@ -1,7 +1,7 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; -import { InMemoryTaskStore, InMemoryTaskMessageQueue } from './inMemoryTaskStore.js'; -import { TaskCreationParams, Request } from '../../types.js'; -import { QueuedMessage } from '../../shared/task.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from './in-memory.js'; +import { TaskCreationParams, Request } from '../../../types.js'; +import { QueuedMessage } from '../interfaces.js'; describe('InMemoryTaskStore', () => { let store: InMemoryTaskStore; diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/experimental/tasks/stores/in-memory.ts similarity index 96% rename from src/examples/shared/inMemoryTaskStore.ts rename to src/experimental/tasks/stores/in-memory.ts index 0e3716bdf..a18229f74 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/experimental/tasks/stores/in-memory.ts @@ -1,5 +1,12 @@ -import { Task, Request, RequestId, Result } from '../../types.js'; -import { TaskStore, isTerminal, TaskMessageQueue, QueuedMessage, CreateTaskOptions } from '../../shared/task.js'; +/** + * In-memory implementations of TaskStore and TaskMessageQueue. + * WARNING: These APIs are experimental and may change without notice. + * + * @experimental + */ + +import { Task, Request, RequestId, Result } from '../../../types.js'; +import { TaskStore, isTerminal, TaskMessageQueue, QueuedMessage, CreateTaskOptions } from '../interfaces.js'; import { randomBytes } from 'crypto'; interface StoredTask { @@ -17,6 +24,8 @@ interface StoredTask { * * Note: This is not suitable for production use as all data is lost on restart. * For production, consider implementing TaskStore with a database or distributed cache. + * + * @experimental */ export class InMemoryTaskStore implements TaskStore { private tasks = new Map(); @@ -212,6 +221,8 @@ export class InMemoryTaskStore implements TaskStore { * * Note: This is not suitable for production use in distributed systems. * For production, consider implementing TaskMessageQueue with Redis or other distributed queues. + * + * @experimental */ export class InMemoryTaskMessageQueue implements TaskMessageQueue { private queues = new Map(); diff --git a/src/shared/task-listing.test.ts b/src/experimental/tasks/task-listing.test.ts similarity index 83% rename from src/shared/task-listing.test.ts rename to src/experimental/tasks/task-listing.test.ts index 7fca7d5e4..7259c969e 100644 --- a/src/shared/task-listing.test.ts +++ b/src/experimental/tasks/task-listing.test.ts @@ -1,9 +1,9 @@ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { InMemoryTransport } from '../inMemory.js'; -import { Client } from '../client/index.js'; -import { Server } from '../server/index.js'; -import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; -import { ErrorCode, McpError } from '../types.js'; +import { InMemoryTransport } from '../../inMemory.js'; +import { Client } from '../../client/index.js'; +import { Server } from '../../server/index.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from './stores/in-memory.js'; +import { ErrorCode, McpError } from '../../types.js'; describe('Task Listing with Pagination', () => { let client: Client; @@ -67,7 +67,7 @@ describe('Task Listing with Pagination', () => { }); it('should return empty list when no tasks exist', async () => { - const result = await client.listTasks(); + const result = await client.experimental.tasks.listTasks(); expect(result.tasks).toEqual([]); expect(result.nextCursor).toBeUndefined(); @@ -82,7 +82,7 @@ describe('Task Listing with Pagination', () => { }); } - const result = await client.listTasks(); + const result = await client.experimental.tasks.listTasks(); expect(result.tasks).toHaveLength(3); expect(result.nextCursor).toBeUndefined(); @@ -98,12 +98,12 @@ describe('Task Listing with Pagination', () => { } // Get first page - const page1 = await client.listTasks(); + const page1 = await client.experimental.tasks.listTasks(); expect(page1.tasks).toHaveLength(10); expect(page1.nextCursor).toBeDefined(); // Get second page using cursor - const page2 = await client.listTasks({ cursor: page1.nextCursor }); + const page2 = await client.experimental.tasks.listTasks(page1.nextCursor); expect(page2.tasks).toHaveLength(5); expect(page2.nextCursor).toBeUndefined(); }); @@ -122,7 +122,7 @@ describe('Task Listing with Pagination', () => { const validCursor = allTasks[2].taskId; // Use the cursor - should work even though we don't know its internal structure - const result = await client.listTasks({ cursor: validCursor }); + const result = await client.experimental.tasks.listTasks(validCursor); expect(result.tasks).toHaveLength(2); }); @@ -133,7 +133,7 @@ describe('Task Listing with Pagination', () => { }); // Try to use an invalid cursor - should return -32602 (Invalid params) per MCP spec - await expect(client.listTasks({ cursor: 'invalid-cursor' })).rejects.toSatisfy((error: McpError) => { + await expect(client.experimental.tasks.listTasks('invalid-cursor')).rejects.toSatisfy((error: McpError) => { expect(error).toBeInstanceOf(McpError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Invalid cursor'); @@ -149,11 +149,11 @@ describe('Task Listing with Pagination', () => { }); // Verify it's accessible via tasks/get - const getResult = await client.getTask({ taskId: task.taskId }); + const getResult = await client.experimental.tasks.getTask(task.taskId); expect(getResult.taskId).toBe(task.taskId); // Verify it's also accessible via tasks/list - const listResult = await client.listTasks(); + const listResult = await client.experimental.tasks.listTasks(); expect(listResult.tasks).toHaveLength(1); expect(listResult.tasks[0].taskId).toBe(task.taskId); }); @@ -165,7 +165,7 @@ describe('Task Listing with Pagination', () => { params: { name: 'test-tool' } }); - const result = await client.listTasks(); + const result = await client.experimental.tasks.listTasks(); // The response should have _meta but not include related-task metadata expect(result._meta).toBeDefined(); diff --git a/src/shared/task.test.ts b/src/experimental/tasks/task.test.ts similarity index 97% rename from src/shared/task.test.ts rename to src/experimental/tasks/task.test.ts index 4d21e3dc3..1318c7558 100644 --- a/src/shared/task.test.ts +++ b/src/experimental/tasks/task.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect } from 'vitest'; -import { isTerminal } from './task.js'; -import type { Task } from '../types.js'; +import { isTerminal } from './interfaces.js'; +import type { Task } from '../../types.js'; describe('Task utility functions', () => { describe('isTerminal', () => { diff --git a/src/experimental/tasks/types.ts b/src/experimental/tasks/types.ts new file mode 100644 index 000000000..a3845bae1 --- /dev/null +++ b/src/experimental/tasks/types.ts @@ -0,0 +1,43 @@ +/** + * Re-exports of task-related types from the MCP protocol spec. + * WARNING: These APIs are experimental and may change without notice. + * + * These types are defined in types.ts (matching the protocol spec) and + * re-exported here for convenience when working with experimental task features. + */ + +// Task schemas (Zod) +export { + TaskCreationParamsSchema, + RelatedTaskMetadataSchema, + TaskSchema, + CreateTaskResultSchema, + TaskStatusNotificationParamsSchema, + TaskStatusNotificationSchema, + GetTaskRequestSchema, + GetTaskResultSchema, + GetTaskPayloadRequestSchema, + ListTasksRequestSchema, + ListTasksResultSchema, + CancelTaskRequestSchema, + CancelTaskResultSchema, + ClientTasksCapabilitySchema, + ServerTasksCapabilitySchema +} from '../../types.js'; + +// Task types (inferred from schemas) +export type { + Task, + TaskCreationParams, + RelatedTaskMetadata, + CreateTaskResult, + TaskStatusNotificationParams, + TaskStatusNotification, + GetTaskRequest, + GetTaskResult, + GetTaskPayloadRequest, + ListTasksRequest, + ListTasksResult, + CancelTaskRequest, + CancelTaskResult +} from '../../types.js'; diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts index fb58b7d78..1a569e485 100644 --- a/src/integration-tests/taskLifecycle.test.ts +++ b/src/integration-tests/taskLifecycle.test.ts @@ -16,7 +16,7 @@ import { TaskSchema } from '../types.js'; import { z } from 'zod'; -import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../experimental/tasks/stores/in-memory.js'; import type { TaskRequestOptions } from '../shared/protocol.js'; describe('Task Lifecycle Integration Tests', () => { @@ -51,7 +51,7 @@ describe('Task Lifecycle Integration Tests', () => { ); // Register a long-running tool using registerToolTask - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'long-task', { title: 'Long Running Task', @@ -105,7 +105,7 @@ describe('Task Lifecycle Integration Tests', () => { ); // Register a tool that requires input via elicitation - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'input-task', { title: 'Input Required Task', @@ -350,8 +350,8 @@ describe('Task Lifecycle Integration Tests', () => { let task = await taskStore.getTask(taskId); expect(task?.status).toBe('working'); - // Cancel the task via client.cancelTask - per spec, returns Result & Task - const cancelResult = await client.cancelTask({ taskId }); + // Cancel the task via client.experimental.tasks.cancelTask - per spec, returns Result & Task + const cancelResult = await client.experimental.tasks.cancelTask(taskId); // Verify the cancel response includes the cancelled task (per MCP spec CancelTaskResult is Result & Task) expect(cancelResult.taskId).toBe(taskId); @@ -403,7 +403,7 @@ describe('Task Lifecycle Integration Tests', () => { expect(task?.status).toBe('completed'); // Try to cancel via tasks/cancel request (should fail with -32602) - await expect(client.cancelTask({ taskId })).rejects.toSatisfy((error: McpError) => { + await expect(client.experimental.tasks.cancelTask(taskId)).rejects.toSatisfy((error: McpError) => { expect(error).toBeInstanceOf(McpError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Cannot cancel task in terminal status'); @@ -417,7 +417,7 @@ describe('Task Lifecycle Integration Tests', () => { describe('Multiple Queued Messages', () => { it('should deliver multiple queued messages in order', async () => { // Register a tool that sends multiple server requests during execution - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'multi-request-task', { title: 'Multi Request Task', @@ -799,7 +799,7 @@ describe('Task Lifecycle Integration Tests', () => { await client.connect(transport); // Try to get non-existent task via tasks/get request - await expect(client.getTask({ taskId: 'non-existent-task-id' })).rejects.toSatisfy((error: McpError) => { + await expect(client.experimental.tasks.getTask('non-existent-task-id')).rejects.toSatisfy((error: McpError) => { expect(error).toBeInstanceOf(McpError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Task not found'); @@ -819,7 +819,7 @@ describe('Task Lifecycle Integration Tests', () => { await client.connect(transport); // Try to cancel non-existent task via tasks/cancel request - await expect(client.cancelTask({ taskId: 'non-existent-task-id' })).rejects.toSatisfy((error: McpError) => { + await expect(client.experimental.tasks.cancelTask('non-existent-task-id')).rejects.toSatisfy((error: McpError) => { expect(error).toBeInstanceOf(McpError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Task not found'); @@ -908,7 +908,7 @@ describe('Task Lifecycle Integration Tests', () => { describe('Task Cancellation with Queued Messages', () => { it('should clear queue and deliver no messages when task is cancelled before tasks/result', async () => { // Register a tool that queues messages but doesn't complete immediately - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'cancellable-task', { title: 'Cancellable Task', @@ -1105,7 +1105,7 @@ describe('Task Lifecycle Integration Tests', () => { describe('Continuous Message Delivery', () => { it('should deliver messages immediately while tasks/result is blocking', async () => { // Register a tool that queues messages over time - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'streaming-task', { title: 'Streaming Task', @@ -1322,7 +1322,7 @@ describe('Task Lifecycle Integration Tests', () => { describe('Terminal Task with Queued Messages', () => { it('should deliver queued messages followed by final result for terminal task', async () => { // Register a tool that completes quickly and queues messages before completion - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'quick-complete-task', { title: 'Quick Complete Task', diff --git a/src/server/index.test.ts b/src/server/index.test.ts index b1fb8a77a..00593bf9c 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -24,7 +24,7 @@ import { } from '../types.js'; import { Server } from './index.js'; import { McpServer } from './mcp.js'; -import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../experimental/tasks/stores/in-memory.js'; import { CallToolRequestSchema, CallToolResultSchema } from '../types.js'; import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js'; import type { AnyObjectSchema } from './zod-compat.js'; @@ -2017,7 +2017,7 @@ describe('Task-based execution', () => { ); // Register a tool using registerToolTask - server.registerToolTask( + server.experimental.tasks.registerToolTask( 'test-tool', { description: 'A test tool', @@ -2078,7 +2078,7 @@ describe('Task-based execution', () => { // Use callToolStream to create a task and capture the task ID let taskId: string | undefined; - const stream = client.callToolStream({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } @@ -2096,12 +2096,12 @@ describe('Task-based execution', () => { await new Promise(resolve => setTimeout(resolve, 50)); // Verify we can retrieve the task - const task = await client.getTask({ taskId: taskId! }); + const task = await client.experimental.tasks.getTask(taskId!); expect(task).toBeDefined(); expect(task.status).toBe('completed'); // Verify we can retrieve the result - const result = await client.getTaskResult({ taskId: taskId! }, CallToolResultSchema); + const result = await client.experimental.tasks.getTaskResult(taskId!, CallToolResultSchema); expect(result.content).toEqual([{ type: 'text', text: 'Tool executed successfully!' }]); // Cleanup @@ -2168,7 +2168,7 @@ describe('Task-based execution', () => { // Try to get a task when server doesn't have TaskStore // The server will return a "Method not found" error - await expect(client.getTask({ taskId: 'non-existent' })).rejects.toThrow('Method not found'); + await expect(client.experimental.tasks.getTask('non-existent')).rejects.toThrow('Method not found'); }); test('should automatically attach related-task metadata to nested requests during tool execution', async () => { @@ -2239,7 +2239,7 @@ describe('Task-based execution', () => { }); // Register a tool using registerToolTask that makes a nested elicitation request - server.registerToolTask( + server.experimental.tasks.registerToolTask( 'collect-info', { description: 'Collects user info via elicitation', @@ -2305,7 +2305,7 @@ describe('Task-based execution', () => { // Call tool WITH task creation using callToolStream to capture task ID let taskId: string | undefined; - const stream = client.callToolStream({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { + const stream = client.experimental.tasks.callToolStream({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { task: { ttl: 60000 } @@ -2326,7 +2326,7 @@ describe('Task-based execution', () => { expect(capturedElicitRequest).toBeDefined(); // Verify tool result was correct - const result = await client.getTaskResult({ taskId: taskId! }, CallToolResultSchema); + const result = await client.experimental.tasks.getTaskResult(taskId!, CallToolResultSchema); expect(result.content).toEqual([ { type: 'text', @@ -2426,7 +2426,7 @@ describe('Task-based execution', () => { const taskId = createTaskResult.task.taskId; // Verify task was created - const task = await server.getTask({ taskId }); + const task = await server.experimental.tasks.getTask(taskId); expect(task.status).toBe('completed'); }); @@ -2503,7 +2503,7 @@ describe('Task-based execution', () => { const taskId = createTaskResult.task.taskId; // Query task - const task = await server.getTask({ taskId }); + const task = await server.experimental.tasks.getTask(taskId); expect(task).toBeDefined(); expect(task.taskId).toBe(taskId); expect(task.status).toBe('completed'); @@ -2585,7 +2585,7 @@ describe('Task-based execution', () => { const taskId = createTaskResult.task.taskId; // Query result - const result = await server.getTaskResult({ taskId }, ElicitResultSchema); + const result = await server.experimental.tasks.getTaskResult(taskId, ElicitResultSchema); expect(result.action).toBe('accept'); expect(result.content).toEqual({ username: 'result-user', confirmed: true }); }); @@ -2679,7 +2679,7 @@ describe('Task-based execution', () => { } // Query task list - const taskList = await server.listTasks(); + const taskList = await server.experimental.tasks.listTasks(); expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); for (const taskId of createdTaskIds) { expect(taskList.tasks).toContainEqual( @@ -2715,7 +2715,7 @@ describe('Task-based execution', () => { ); // Register a tool using registerToolTask with variable delay - server.registerToolTask( + server.experimental.tasks.registerToolTask( 'async-tool', { description: 'An async test tool', @@ -2791,22 +2791,22 @@ describe('Task-based execution', () => { await new Promise(resolve => setTimeout(resolve, 50)); // Get all task IDs from the task list - const taskList = await client.listTasks(); + const taskList = await client.experimental.tasks.listTasks(); expect(taskList.tasks.length).toBeGreaterThanOrEqual(4); const taskIds = taskList.tasks.map(t => t.taskId); // Verify all tasks completed successfully for (let i = 0; i < taskIds.length; i++) { - const task = await client.getTask({ taskId: taskIds[i] }); + const task = await client.experimental.tasks.getTask(taskIds[i]); expect(task.status).toBe('completed'); expect(task.taskId).toBe(taskIds[i]); - const result = await client.getTaskResult({ taskId: taskIds[i] }, CallToolResultSchema); + const result = await client.experimental.tasks.getTaskResult(taskIds[i], CallToolResultSchema); expect(result.content).toEqual([{ type: 'text', text: `Completed task ${i + 1}` }]); } // Verify listTasks returns all tasks - const finalTaskList = await client.listTasks(); + const finalTaskList = await client.experimental.tasks.listTasks(); for (const taskId of taskIds) { expect(finalTaskList.tasks).toContainEqual(expect.objectContaining({ taskId })); } @@ -2873,7 +2873,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Try to query a task that doesn't exist - await expect(client.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + await expect(client.experimental.tasks.getTask('non-existent-task')).rejects.toThrow(); }); test('should throw error when server queries non-existent task from client', async () => { @@ -2925,7 +2925,7 @@ describe('Task-based execution', () => { await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Try to query a task that doesn't exist on client - await expect(server.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + await expect(server.experimental.tasks.getTask('non-existent-task')).rejects.toThrow(); }); }); }); @@ -3033,8 +3033,8 @@ test('should respect client task capabilities', async () => { expect(createTaskResult.task.taskId).toBeDefined(); const taskId = createTaskResult.task.taskId; - await expect(server.listTasks()).resolves.not.toThrow(); - await expect(server.getTask({ taskId })).resolves.not.toThrow(); + await expect(server.experimental.tasks.listTasks()).resolves.not.toThrow(); + await expect(server.experimental.tasks.getTask(taskId)).resolves.not.toThrow(); // This should throw because client doesn't support task creation for sampling/createMessage await expect( diff --git a/src/server/index.ts b/src/server/index.ts index 23061bf98..dfbb2a2a3 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,5 +1,4 @@ import { mergeCapabilities, Protocol, type NotificationOptions, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; -import { ResponseMessage } from '../shared/responseMessage.js'; import { type ClientCapabilities, type CreateMessageRequest, @@ -42,7 +41,6 @@ import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, jsonSchemaValidator } from '../validation/types.js'; import { AnyObjectSchema, - AnySchema, getObjectShape, isZ4Schema, safeParse, @@ -51,6 +49,8 @@ import { type ZodV4Internal } from './zod-compat.js'; import { RequestHandlerExtra } from '../shared/protocol.js'; +import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; +import { assertToolsCallTaskCapability, assertClientRequestTaskCapability } from '../experimental/tasks/helpers.js'; export type ServerOptions = ProtocolOptions & { /** @@ -131,6 +131,7 @@ export class Server< private _capabilities: ServerCapabilities; private _instructions?: string; private _jsonSchemaValidator: jsonSchemaValidator; + private _experimental?: { tasks: ExperimentalServerTasks }; /** * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). @@ -166,6 +167,22 @@ export class Server< } } + /** + * Access experimental features. + * + * WARNING: These APIs are experimental and may change without notice. + * + * @experimental + */ + get experimental(): { tasks: ExperimentalServerTasks } { + if (!this._experimental) { + this._experimental = { + tasks: new ExperimentalServerTasks(this) + }; + } + return this._experimental; + } + // Map log levels by session id private _loggingLevels = new Map(); @@ -399,29 +416,7 @@ export class Server< } protected assertTaskCapability(method: string): void { - if (!this._clientCapabilities?.tasks?.requests) { - throw new Error(`Client does not support task creation (required for ${method})`); - } - - const requests = this._clientCapabilities.tasks.requests; - - switch (method) { - case 'sampling/createMessage': - if (!requests.sampling?.createMessage) { - throw new Error(`Client does not support task creation for sampling/createMessage (required for ${method})`); - } - break; - - case 'elicitation/create': - if (!requests.elicitation?.create) { - throw new Error(`Client does not support task creation for elicitation/create (required for ${method})`); - } - break; - - default: - // Method doesn't support tasks, which is fine - no error - break; - } + assertClientRequestTaskCapability(this._clientCapabilities?.tasks?.requests, method, 'Client'); } protected assertTaskHandlerCapability(method: string): void { @@ -431,23 +426,7 @@ export class Server< return; } - if (!this._capabilities.tasks?.requests) { - throw new Error(`Server does not support task creation (required for ${method})`); - } - - const requests = this._capabilities.tasks.requests; - - switch (method) { - case 'tools/call': - if (!requests.tools?.call) { - throw new Error(`Server does not support task creation for tools/call (required for ${method})`); - } - break; - - default: - // Method doesn't support tasks, which is fine - no error - break; - } + assertToolsCallTaskCapability(this._capabilities.tasks?.requests, method, 'Server'); } private async _oninitialize(request: InitializeRequest): Promise { @@ -484,47 +463,6 @@ export class Server< return this._capabilities; } - /** - * Sends a request and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a 'result' or 'error' message. - * - * This method provides streaming access to request processing, allowing you to - * observe intermediate task status updates for task-augmented requests. - * - * @example - * ```typescript - * const stream = server.requestStream(request, resultSchema, options); - * for await (const message of stream) { - * switch (message.type) { - * case 'taskCreated': - * console.log('Task created:', message.task.taskId); - * break; - * case 'taskStatus': - * console.log('Task status:', message.task.status); - * break; - * case 'result': - * console.log('Final result:', message.result); - * break; - * case 'error': - * console.error('Error:', message.error); - * break; - * } - * } - * ``` - * - * @param request - The request to send - * @param resultSchema - Zod schema for validating the result - * @param options - Optional request options (timeout, signal, task creation params, etc.) - * @returns AsyncGenerator that yields ResponseMessage objects - */ - requestStream( - request: ServerRequest | RequestT, - resultSchema: T, - options?: RequestOptions - ): AsyncGenerator>, void, void> { - return super.requestStream(request, resultSchema, options); - } - async ping() { return this.request({ method: 'ping' }, EmptyResultSchema); } diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index c3fd30cc3..2ad40ba5a 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -21,7 +21,7 @@ import { } from '../types.js'; import { completable } from './completable.js'; import { McpServer, ResourceTemplate } from './mcp.js'; -import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTaskStore } from '../experimental/tasks/stores/in-memory.js'; import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; function createLatch() { @@ -1803,7 +1803,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Register a tool with execution.taskSupport - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'task-tool', { description: 'A tool with task support', @@ -1872,7 +1872,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Register a tool with execution.taskSupport optional - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'optional-task-tool', { description: 'A tool with optional task support', @@ -5874,7 +5874,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { ); // Register a task-based tool with taskSupport "required" - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'long-running-task', { description: 'A long running task', @@ -5979,7 +5979,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { ); // Register a task-based tool with taskSupport "optional" - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'optional-task', { description: 'An optional task', @@ -6087,7 +6087,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { ); // Register a task-based tool with taskSupport "required" - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'task-tool', { description: 'A task tool', @@ -6207,7 +6207,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { ); // Register a task-based tool that fails - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'failing-task', { description: 'A failing task', @@ -6313,7 +6313,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { ); // Register a task-based tool that gets cancelled - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'cancelled-task', { description: 'A task that gets cancelled', @@ -6397,7 +6397,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Attempt to register a task-based tool with taskSupport "forbidden" (cast to bypass type checking) expect(() => { - mcpServer.registerToolTask( + mcpServer.experimental.tasks.registerToolTask( 'invalid-task', { description: 'A task with forbidden support', diff --git a/src/server/mcp.ts b/src/server/mcp.ts index a727a4f33..68a81764a 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -46,7 +46,6 @@ import { ToolAnnotations, LoggingMessageNotification, CreateTaskResult, - GetTaskResult, Result, CompleteRequestPrompt, CompleteRequestResourceTemplate, @@ -57,10 +56,12 @@ import { } from '../types.js'; import { isCompletable, getCompleter } from './completable.js'; import { UriTemplate, Variables } from '../shared/uriTemplate.js'; -import { RequestHandlerExtra, RequestTaskStore } from '../shared/protocol.js'; +import { RequestHandlerExtra } from '../shared/protocol.js'; import { Transport } from '../shared/transport.js'; import { validateAndWarnToolName } from '../shared/toolNameValidation.js'; +import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcp-server.js'; +import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; /** * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. @@ -79,11 +80,28 @@ export class McpServer { } = {}; private _registeredTools: { [name: string]: RegisteredTool } = {}; private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + private _experimental?: { tasks: ExperimentalMcpServerTasks }; constructor(serverInfo: Implementation, options?: ServerOptions) { this.server = new Server(serverInfo, options); } + /** + * Access experimental features. + * + * WARNING: These APIs are experimental and may change without notice. + * + * @experimental + */ + get experimental(): { tasks: ExperimentalMcpServerTasks } { + if (!this._experimental) { + this._experimental = { + tasks: new ExperimentalMcpServerTasks(this) + }; + } + return this._experimental; + } + /** * Attaches to the given transport, starts it, and starts listening for messages. * @@ -1043,74 +1061,6 @@ export class McpServer { ); } - /** - * Registers a task-based tool with a config object and callback. - */ - registerToolTask( - name: string, - config: { - title?: string; - description?: string; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - execution?: TaskToolExecution; - _meta?: Record; - }, - handler: ToolTaskHandler - ): RegisteredTool; - - /** - * Registers a task-based tool with a config object and callback. - */ - registerToolTask( - name: string, - config: { - title?: string; - description?: string; - inputSchema: InputArgs; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - execution?: TaskToolExecution; - _meta?: Record; - }, - handler: ToolTaskHandler - ): RegisteredTool; - - registerToolTask< - InputArgs extends undefined | ZodRawShapeCompat | AnySchema, - OutputArgs extends undefined | ZodRawShapeCompat | AnySchema - >( - name: string, - config: { - title?: string; - description?: string; - inputSchema?: InputArgs; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - execution?: TaskToolExecution; - _meta?: Record; - }, - handler: ToolTaskHandler - ): RegisteredTool { - // Validate that taskSupport is not 'forbidden' for task-based tools - const execution: ToolExecution = { taskSupport: 'required', ...config.execution }; - if (execution.taskSupport === 'forbidden') { - throw new Error(`Cannot register task-based tool '${name}' with taskSupport 'forbidden'. Use registerTool() instead.`); - } - - return this._createRegisteredTool( - name, - config.title, - config.description, - config.inputSchema, - config.outputSchema, - config.annotations, - execution, - config._meta, - handler - ); - } - /** * Registers a zero-argument prompt `name`, which will run the given function when the client calls it. * @deprecated Use `registerPrompt` instead. @@ -1326,47 +1276,11 @@ export type ToolCallback; -export interface CreateTaskRequestHandlerExtra extends RequestHandlerExtra { - taskStore: RequestTaskStore; -} - -export interface TaskRequestHandlerExtra extends RequestHandlerExtra { - taskId: string; - taskStore: RequestTaskStore; -} - -export type CreateTaskRequestHandler< - SendResultT extends Result, - Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; - -export type TaskRequestHandler< - SendResultT extends Result, - Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; - -export interface ToolTaskHandler { - createTask: CreateTaskRequestHandler; - getTask: TaskRequestHandler; - getTaskResult: TaskRequestHandler; -} - -/** - * Supertype for tool handler callbacks registered with Server.registerTool() and Server.registerToolTask(). - */ -export type AnyToolCallback = - | ToolCallback - | TaskRequestHandler; - /** * Supertype that can handle both regular tools (simple callback) and task-based tools (task handler object). */ export type AnyToolHandler = ToolCallback | ToolTaskHandler; -export type TaskToolExecution = Omit & { - taskSupport: TaskSupport extends 'forbidden' | undefined ? never : TaskSupport; -}; - export type RegisteredTool = { title?: string; description?: string; diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index f56c806fc..68f843156 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -16,13 +16,13 @@ import { } from '../types.js'; import { Protocol, mergeCapabilities } from './protocol.js'; import { Transport, TransportSendOptions } from './transport.js'; -import { TaskStore, TaskMessageQueue, QueuedMessage, QueuedNotification, QueuedRequest } from './task.js'; +import { TaskStore, TaskMessageQueue, QueuedMessage, QueuedNotification, QueuedRequest } from '../experimental/tasks/interfaces.js'; import { MockInstance, vi } from 'vitest'; import { JSONRPCResponse, JSONRPCRequest, JSONRPCError } from '../types.js'; import { ErrorMessage, ResponseMessage, toArrayAsync } from './responseMessage.js'; -import { InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; +import { InMemoryTaskMessageQueue } from '../experimental/tasks/stores/in-memory.js'; -// Type helper for accessing private Protocol properties in tests +// Type helper for accessing private/protected Protocol properties in tests interface TestProtocol { _taskMessageQueue?: TaskMessageQueue; _requestResolvers: Map void>; @@ -30,6 +30,10 @@ interface TestProtocol { _taskProgressTokens: Map; _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; requestTaskStore: (request: Request, authInfo: unknown) => TaskStore; + // Protected task methods (exposed for testing) + listTasks: (params?: { cursor?: string }) => Promise<{ tasks: Task[]; nextCursor?: string }>; + cancelTask: (params: { taskId: string }) => Promise; + requestStream: (request: Request, schema: ZodType, options?: unknown) => AsyncGenerator>; } // Mock Transport class @@ -1487,7 +1491,7 @@ describe('Task-based execution', () => { it('should call listTasks method from client side', async () => { await protocol.connect(transport); - const listTasksPromise = protocol.listTasks(); + const listTasksPromise = (protocol as unknown as TestProtocol).listTasks(); // Simulate server response setTimeout(() => { @@ -1527,7 +1531,7 @@ describe('Task-based execution', () => { it('should call listTasks with cursor from client side', async () => { await protocol.connect(transport); - const listTasksPromise = protocol.listTasks({ cursor: 'task-10' }); + const listTasksPromise = (protocol as unknown as TestProtocol).listTasks({ cursor: 'task-10' }); // Simulate server response setTimeout(() => { @@ -1713,7 +1717,7 @@ describe('Task-based execution', () => { it('should call cancelTask method from client side', async () => { await protocol.connect(transport); - const deleteTaskPromise = protocol.cancelTask({ taskId: 'task-to-delete' }); + const deleteTaskPromise = (protocol as unknown as TestProtocol).cancelTask({ taskId: 'task-to-delete' }); // Simulate server response - per MCP spec, CancelTaskResult is Result & Task setTimeout(() => { @@ -4458,7 +4462,10 @@ describe('requestStream() method', () => { // Start the request stream const streamPromise = (async () => { const messages = []; - const stream = protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema); + const stream = (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema + ); for await (const message of stream) { messages.push(message); } @@ -4498,7 +4505,10 @@ describe('requestStream() method', () => { // Start the request stream const streamPromise = (async () => { const messages = []; - const stream = protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema); + const stream = (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema + ); for await (const message of stream) { messages.push(message); } @@ -4545,9 +4555,13 @@ describe('requestStream() method', () => { // Start the request stream with already-aborted signal const messages = []; - const stream = protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { - signal: abortController.signal - }); + const stream = (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema, + { + signal: abortController.signal + } + ); for await (const message of stream) { messages.push(message); } @@ -4573,7 +4587,10 @@ describe('requestStream() method', () => { await protocol.connect(transport); const messagesPromise = toArrayAsync( - protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema + ) ); // Simulate server error response @@ -4612,9 +4629,13 @@ describe('requestStream() method', () => { await protocol.connect(transport); const messagesPromise = toArrayAsync( - protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { - timeout: 100 - }) + (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema, + { + timeout: 100 + } + ) ); // Advance time to trigger timeout @@ -4650,9 +4671,13 @@ describe('requestStream() method', () => { // Collect messages const messages = await toArrayAsync( - protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { - signal: abortController.signal - }) + (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema, + { + signal: abortController.signal + } + ) ); // Verify error is terminal and last message @@ -4675,7 +4700,10 @@ describe('requestStream() method', () => { await protocol.connect(transport); const messagesPromise = toArrayAsync( - protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema + ) ); // Simulate server error response @@ -4724,7 +4752,10 @@ describe('requestStream() method', () => { await protocol.connect(transport); const messagesPromise = toArrayAsync( - protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema + ) ); // Simulate task creation response @@ -4784,7 +4815,10 @@ describe('requestStream() method', () => { transport.send = vi.fn().mockRejectedValue(new Error('Network error')); const messages = await toArrayAsync( - protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema + ) ); // Verify error is terminal and last message @@ -4806,7 +4840,10 @@ describe('requestStream() method', () => { await protocol.connect(transport); const messagesPromise = toArrayAsync( - protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + CallToolResultSchema + ) ); // Simulate server error response diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 31532d85a..cac95fcc4 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -45,7 +45,7 @@ import { } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; -import { isTerminal, TaskStore, TaskMessageQueue, QueuedMessage, CreateTaskOptions } from './task.js'; +import { isTerminal, TaskStore, TaskMessageQueue, QueuedMessage, CreateTaskOptions } from '../experimental/tasks/interfaces.js'; import { getMethodLiteral, parseWithCompat } from '../server/zod-json-schema-compat.js'; import { ResponseMessage } from './responseMessage.js'; @@ -956,8 +956,10 @@ export abstract class Protocol( + protected async *requestStream( request: SendRequestT, resultSchema: T, options?: RequestOptions @@ -1199,16 +1201,20 @@ export abstract class Protocol { + protected async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { // @ts-expect-error SendRequestT cannot directly contain GetTaskRequest, but we ensure all type instantiations contain it anyways return this.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); } /** * Retrieves the result of a completed task. + * + * @experimental Use `client.experimental.tasks.getTaskResult()` to access this method. */ - async getTaskResult( + protected async getTaskResult( params: GetTaskPayloadRequest['params'], resultSchema: T, options?: RequestOptions @@ -1219,16 +1225,20 @@ export abstract class Protocol> { + protected async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { // @ts-expect-error SendRequestT cannot directly contain ListTasksRequest, but we ensure all type instantiations contain it anyways return this.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); } /** * Cancels a specific task. + * + * @experimental Use `client.experimental.tasks.cancelTask()` to access this method. */ - async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { + protected async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { // @ts-expect-error SendRequestT cannot directly contain CancelTaskRequest, but we ensure all type instantiations contain it anyways return this.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); } From ad3fee3122930895d846637a9a97ffc1ab0fc893 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Wed, 26 Nov 2025 21:54:59 +0000 Subject: [PATCH 83/84] refactor: decouple callTool() from experimental callToolStream() Restore callTool() to its original implementation instead of delegating to experimental.tasks.callToolStream(). This aligns with Python SDK's approach where call_tool() is task-unaware and call_tool_as_task() is the explicit experimental method. Changes: - Add guard for taskSupport: 'required' tools with clear error message - Restore original output schema validation logic - Add _cachedRequiredTaskTools to track required-only task tools - Remove unused takeResult import Tools with taskSupport: 'optional' work normally with callTool() since the server returns CallToolResult. Only 'required' tools need the experimental API. --- src/client/index.ts | 68 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 6 deletions(-) diff --git a/src/client/index.ts b/src/client/index.ts index d6dfe82d9..0fb6cdcf3 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,6 +1,5 @@ import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; -import { takeResult } from '../shared/responseMessage.js'; import { type CallToolRequest, @@ -203,6 +202,7 @@ export class Client< private _jsonSchemaValidator: jsonSchemaValidator; private _cachedToolOutputValidators: Map> = new Map(); private _cachedKnownTaskTools: Set = new Set(); + private _cachedRequiredTaskTools: Set = new Set(); private _experimental?: { tasks: ExperimentalClientTasks }; /** @@ -645,13 +645,57 @@ export class Client< * * For task-based execution with streaming behavior, use client.experimental.tasks.callToolStream() instead. */ - async callTool( + async callTool( params: CallToolRequest['params'], - resultSchema: T = CallToolResultSchema as T, + resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, options?: RequestOptions - ): Promise> { - // Use experimental.tasks.callToolStream for implementation (temporary dependency) - return await takeResult(this.experimental.tasks.callToolStream(params, resultSchema, options)); + ) { + // Guard: required-task tools need experimental API + if (this.isToolTaskRequired(params.name)) { + throw new McpError( + ErrorCode.InvalidRequest, + `Tool "${params.name}" requires task-based execution. Use client.experimental.tasks.callToolStream() instead.` + ); + } + + const result = await this.request({ method: 'tools/call', params }, resultSchema, options); + + // Check if the tool has an outputSchema + const validator = this.getToolOutputValidator(params.name); + if (validator) { + // If tool has outputSchema, it MUST return structuredContent (unless it's an error) + if (!result.structuredContent && !result.isError) { + throw new McpError( + ErrorCode.InvalidRequest, + `Tool ${params.name} has an output schema but did not return structured content` + ); + } + + // Only validate structured content if present (not when there's an error) + if (result.structuredContent) { + try { + // Validate the structured content against the schema + const validationResult = validator(result.structuredContent); + + if (!validationResult.valid) { + throw new McpError( + ErrorCode.InvalidParams, + `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` + ); + } + } catch (error) { + if (error instanceof McpError) { + throw error; + } + throw new McpError( + ErrorCode.InvalidParams, + `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + } + + return result; } private isToolTask(toolName: string): boolean { @@ -662,6 +706,14 @@ export class Client< return this._cachedKnownTaskTools.has(toolName); } + /** + * Check if a tool requires task-based execution. + * Unlike isToolTask which includes 'optional' tools, this only checks for 'required'. + */ + private isToolTaskRequired(toolName: string): boolean { + return this._cachedRequiredTaskTools.has(toolName); + } + /** * Cache validators for tool output schemas. * Called after listTools() to pre-compile validators for better performance. @@ -669,6 +721,7 @@ export class Client< private cacheToolMetadata(tools: Tool[]): void { this._cachedToolOutputValidators.clear(); this._cachedKnownTaskTools.clear(); + this._cachedRequiredTaskTools.clear(); for (const tool of tools) { // If the tool has an outputSchema, create and cache the validator @@ -682,6 +735,9 @@ export class Client< if (taskSupport === 'required' || taskSupport === 'optional') { this._cachedKnownTaskTools.add(tool.name); } + if (taskSupport === 'required') { + this._cachedRequiredTaskTools.add(tool.name); + } } } From 0e89248fbc0ef56ef63e4035f543bd366df90b29 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 11:12:13 +0000 Subject: [PATCH 84/84] test: add callToolStream tests for non-task tools --- src/client/index.test.ts | 463 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 463 insertions(+) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index e161403fc..4efd2adac 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -3088,6 +3088,469 @@ test('should validate structured output in callToolStream()', async () => { await server.close(); }); +test('callToolStream() should yield error when structuredContent does not match schema', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + result: { type: 'string' }, + count: { type: 'number' } + }, + required: ['result', 'count'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async () => { + // Return invalid structured content (count is string instead of number) + return { + structuredContent: { result: 'success', count: 'not a number' } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the schemas + await client.listTools(); + + const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('error'); + if (messages[0].type === 'error') { + expect(messages[0].error.message).toMatch(/Structured content does not match the tool's output schema/); + } + + await client.close(); + await server.close(); +}); + +test('callToolStream() should yield error when tool with outputSchema returns no structuredContent', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + result: { type: 'string' } + }, + required: ['result'] + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + content: [{ type: 'text', text: 'This should be structured content' }] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await client.listTools(); + + const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('error'); + if (messages[0].type === 'error') { + expect(messages[0].error.message).toMatch(/Tool test-tool has an output schema but did not return structured content/); + } + + await client.close(); + await server.close(); +}); + +test('callToolStream() should handle tools without outputSchema normally', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + content: [{ type: 'text', text: 'Normal response' }] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await client.listTools(); + + const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('result'); + if (messages[0].type === 'result') { + expect(messages[0].result.content).toEqual([{ type: 'text', text: 'Normal response' }]); + } + + await client.close(); + await server.close(); +}); + +test('callToolStream() should handle complex JSON schema validation', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'complex-tool', + description: 'A tool with complex schema', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + name: { type: 'string', minLength: 3 }, + age: { type: 'integer', minimum: 0, maximum: 120 }, + active: { type: 'boolean' }, + tags: { + type: 'array', + items: { type: 'string' }, + minItems: 1 + }, + metadata: { + type: 'object', + properties: { + created: { type: 'string' } + }, + required: ['created'] + } + }, + required: ['name', 'age', 'active', 'tags', 'metadata'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + structuredContent: { + name: 'John Doe', + age: 30, + active: true, + tags: ['user', 'admin'], + metadata: { + created: '2023-01-01T00:00:00Z' + } + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await client.listTools(); + + const stream = client.experimental.tasks.callToolStream({ name: 'complex-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('result'); + if (messages[0].type === 'result') { + expect(messages[0].result.structuredContent).toBeDefined(); + const structuredContent = messages[0].result.structuredContent as { name: string; age: number }; + expect(structuredContent.name).toBe('John Doe'); + expect(structuredContent.age).toBe(30); + } + + await client.close(); + await server.close(); +}); + +test('callToolStream() should yield error with additional properties when not allowed', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'strict-tool', + description: 'A tool with strict schema', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + name: { type: 'string' } + }, + required: ['name'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + structuredContent: { + name: 'John', + extraField: 'not allowed' + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await client.listTools(); + + const stream = client.experimental.tasks.callToolStream({ name: 'strict-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('error'); + if (messages[0].type === 'error') { + expect(messages[0].error.message).toMatch(/Structured content does not match the tool's output schema/); + } + + await client.close(); + await server.close(); +}); + +test('callToolStream() should not validate structuredContent when isError is true', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + result: { type: 'string' } + }, + required: ['result'] + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async () => { + // Return isError with content (no structuredContent) - should NOT trigger validation error + return { + isError: true, + content: [{ type: 'text', text: 'Something went wrong' }] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await client.listTools(); + + const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + // Should have received result (not error), with isError flag set + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('result'); + if (messages[0].type === 'result') { + expect(messages[0].result.isError).toBe(true); + expect(messages[0].result.content).toEqual([{ type: 'text', text: 'Something went wrong' }]); + } + + await client.close(); + await server.close(); +}); + describe('getSupportedElicitationModes', () => { test('should support nothing when capabilities are undefined', () => { const result = getSupportedElicitationModes(undefined);