diff --git a/src/main/events.ts b/src/main/events.ts index 12b8812b2..f58649c1e 100644 --- a/src/main/events.ts +++ b/src/main/events.ts @@ -112,7 +112,10 @@ export const MCP_EVENTS = { TOOL_CALL_RESULT: 'mcp:tool-call-result', SERVER_STATUS_CHANGED: 'mcp:server-status-changed', CLIENT_LIST_UPDATED: 'mcp:client-list-updated', - INITIALIZED: 'mcp:initialized' // 新增:MCP初始化完成事件 + INITIALIZED: 'mcp:initialized', // 新增:MCP初始化完成事件 + SAMPLING_REQUEST: 'mcp:sampling-request', + SAMPLING_DECISION: 'mcp:sampling-decision', + SAMPLING_CANCELLED: 'mcp:sampling-cancelled' } // 同步相关事件 diff --git a/src/main/presenter/mcpPresenter/index.ts b/src/main/presenter/mcpPresenter/index.ts index 733bc22f4..d56711713 100644 --- a/src/main/presenter/mcpPresenter/index.ts +++ b/src/main/presenter/mcpPresenter/index.ts @@ -8,7 +8,9 @@ import { Prompt, ResourceListEntry, Resource, - PromptListEntry + PromptListEntry, + McpSamplingRequestPayload, + McpSamplingDecision } from '@shared/presenter' import { ServerManager } from './serverManager' import { ToolManager } from './toolManager' @@ -85,6 +87,10 @@ export class McpPresenter implements IMCPPresenter { private isInitialized: boolean = false // McpRouter private mcprouter?: McpRouterManager + private pendingSamplingRequests = new Map< + string, + { resolve: (decision: McpSamplingDecision) => void; reject: (error: Error) => void } + >() constructor(configPresenter?: IConfigPresenter) { console.log('Initializing MCP Presenter') @@ -580,6 +586,60 @@ export class McpPresenter implements IMCPPresenter { return { content: formattedContent, rawData: toolCallResult } } + async handleSamplingRequest(request: McpSamplingRequestPayload): Promise { + if (!request || !request.requestId) { + throw new Error('Invalid sampling request: missing requestId') + } + + return new Promise((resolve, reject) => { + try { + this.pendingSamplingRequests.set(request.requestId, { resolve, reject }) + eventBus.sendToRenderer(MCP_EVENTS.SAMPLING_REQUEST, SendTarget.DEFAULT_TAB, request) + } catch (error) { + this.pendingSamplingRequests.delete(request.requestId) + reject(error instanceof Error ? error : new Error(String(error))) + } + }) + } + + async submitSamplingDecision(decision: McpSamplingDecision): Promise { + if (!decision || !decision.requestId) { + throw new Error('Invalid sampling decision: missing requestId') + } + + const pending = this.pendingSamplingRequests.get(decision.requestId) + if (!pending) { + console.warn( + `[MCP] Sampling request ${decision.requestId} not found when submitting decision` + ) + return + } + + this.pendingSamplingRequests.delete(decision.requestId) + pending.resolve(decision) + + eventBus.sendToRenderer(MCP_EVENTS.SAMPLING_DECISION, SendTarget.ALL_WINDOWS, decision) + } + + async cancelSamplingRequest(requestId: string, reason?: string): Promise { + if (!requestId) { + return + } + + const pending = this.pendingSamplingRequests.get(requestId) + if (!pending) { + return + } + + this.pendingSamplingRequests.delete(requestId) + pending.reject(new Error(reason ?? 'Sampling request cancelled')) + + eventBus.sendToRenderer(MCP_EVENTS.SAMPLING_CANCELLED, SendTarget.ALL_WINDOWS, { + requestId, + reason: reason ?? 'cancelled' + }) + } + // Convert MCPToolDefinition to MCPTool private mcpToolDefinitionToMcpTool( toolDefinition: MCPToolDefinition, diff --git a/src/main/presenter/mcpPresenter/mcpClient.ts b/src/main/presenter/mcpPresenter/mcpClient.ts index 647c4941c..1ceb05e6d 100644 --- a/src/main/presenter/mcpPresenter/mcpClient.ts +++ b/src/main/presenter/mcpPresenter/mcpClient.ts @@ -8,8 +8,12 @@ import { PromptListChangedNotificationSchema, ResourceListChangedNotificationSchema, ResourceUpdatedNotificationSchema, - LoggingMessageNotificationSchema + LoggingMessageNotificationSchema, + CreateMessageRequestSchema, + ErrorCode, + McpError } from '@modelcontextprotocol/sdk/types.js' +import type { CreateMessageRequest, CreateMessageResult } from '@modelcontextprotocol/sdk/types.js' import { eventBus, SendTarget } from '@/eventbus' import { MCP_EVENTS } from '@/events' import path from 'path' @@ -25,9 +29,19 @@ import { Tool, Prompt, ResourceListEntry, - Resource + Resource, + ChatMessage, + McpSamplingRequestPayload, + McpSamplingDecision } from '@shared/presenter' +const ALLOWED_SAMPLING_IMAGE_MIME_TYPES = new Set([ + 'image/png', + 'image/jpeg', + 'image/gif', + 'image/webp' +]) + // TODO: resources 和 prompts 的类型,Notifactions 的类型 https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/examples/client/simpleStreamableHttp.ts // Simple OAuth provider for handling Bearer Token class SimpleOAuthProvider { @@ -58,6 +72,12 @@ interface SessionError extends Error { isSessionExpired?: boolean } +interface RequestHandlerContext { + signal?: AbortSignal + requestId?: string | number + [key: string]: unknown +} + // Helper function to check if error is session-related function isSessionError(error: unknown): error is SessionError { if (error instanceof Error) { @@ -590,7 +610,8 @@ export class McpClient { capabilities: { resources: {}, tools: {}, - prompts: {} + prompts: {}, + sampling: {} } } ) @@ -598,6 +619,11 @@ export class McpClient { // 设置通知处理器 this.registerNotificationHandlers() + // 注册采样请求处理器 + this.client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { + return this.handleSamplingCreateMessage(request, extra) + }) + // 设置连接超时 const timeoutPromise = new Promise((_, reject) => { this.connectionTimeout = setTimeout( @@ -764,6 +790,296 @@ export class McpClient { }) } + private async handleSamplingCreateMessage( + request: CreateMessageRequest, + extra: RequestHandlerContext + ): Promise { + const params = request.params ?? {} + const requestId = this.resolveSamplingRequestId(extra) + const { payload, chatMessages } = this.prepareSamplingContext(requestId, params) + + const decisionPromise = presenter.mcpPresenter.handleSamplingRequest(payload) + const signal = extra?.signal as AbortSignal | undefined + + let decision: McpSamplingDecision + if (signal) { + decision = await new Promise((resolve, reject) => { + const onAbort = () => { + signal.removeEventListener('abort', onAbort) + void presenter.mcpPresenter + .cancelSamplingRequest(payload.requestId, 'cancelled by server') + .catch((error) => { + console.warn(`[MCP] Failed to cancel sampling request ${payload.requestId}:`, error) + }) + reject(new McpError(ErrorCode.RequestTimeout, 'Sampling request cancelled')) + } + + if (signal.aborted) { + onAbort() + return + } + + signal.addEventListener('abort', onAbort, { once: true }) + decisionPromise + .then((value) => { + signal.removeEventListener('abort', onAbort) + resolve(value) + }) + .catch((error) => { + signal.removeEventListener('abort', onAbort) + reject(error) + }) + }) + } else { + decision = await decisionPromise + } + + if (!decision.approved) { + throw new McpError(ErrorCode.InvalidRequest, 'User rejected sampling request') + } + + if (!decision.providerId || !decision.modelId) { + throw new McpError(ErrorCode.InvalidParams, 'No model selected for sampling request') + } + + let assistantText = '' + try { + assistantText = await presenter.llmproviderPresenter.generateCompletionStandalone( + decision.providerId, + chatMessages, + decision.modelId, + undefined, + params.maxTokens + ) + } catch (error) { + console.error(`[MCP] Sampling request failed for server ${this.serverName}:`, error) + throw new McpError( + ErrorCode.InternalError, + error instanceof Error ? error.message : 'Sampling request failed' + ) + } + + const modelName = + this.resolveModelDisplayName(decision.providerId, decision.modelId) ?? decision.modelId + + const result: CreateMessageResult = { + role: 'assistant', + model: modelName, + stopReason: 'endTurn', + content: { + type: 'text', + text: assistantText ?? '' + } + } + + return result + } + + private resolveSamplingRequestId(extra: RequestHandlerContext): string { + const rawId = extra?.requestId + if (typeof rawId === 'string' || typeof rawId === 'number') { + return String(rawId) + } + + return `${this.serverName}-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 8)}` + } + + private prepareSamplingContext( + requestId: string, + params: CreateMessageRequest['params'] + ): { payload: McpSamplingRequestPayload; chatMessages: ChatMessage[] } { + const payload: McpSamplingRequestPayload = { + requestId, + serverName: this.serverName, + serverLabel: this.getServerLabel(), + systemPrompt: typeof params?.systemPrompt === 'string' ? params.systemPrompt : undefined, + maxTokens: typeof params?.maxTokens === 'number' ? params.maxTokens : undefined, + modelPreferences: this.normalizeModelPreferences(params?.modelPreferences), + requiresVision: false, + messages: [] + } + + const chatMessages: ChatMessage[] = [] + + if (payload.systemPrompt) { + chatMessages.push({ role: 'system', content: payload.systemPrompt }) + } + + const messageList = Array.isArray(params?.messages) ? params.messages : [] + + for (const message of messageList) { + if (!message || (message.role !== 'user' && message.role !== 'assistant')) { + continue + } + + const rawContent = message.content + if (!rawContent || typeof rawContent !== 'object' || !('type' in rawContent)) { + throw new McpError(ErrorCode.InvalidParams, 'Invalid sampling message content received') + } + + const content = rawContent as { type: string } & Record + + if (content.type === 'text') { + const text = typeof content.text === 'string' ? content.text : '' + payload.messages.push({ role: message.role, type: 'text', text }) + chatMessages.push({ role: message.role, content: text }) + } else if (content.type === 'image') { + const rawMimeType = typeof content.mimeType === 'string' ? content.mimeType : undefined + const normalizedMimeType = rawMimeType?.toLowerCase() + + if (normalizedMimeType && !ALLOWED_SAMPLING_IMAGE_MIME_TYPES.has(normalizedMimeType)) { + throw new McpError( + ErrorCode.InvalidParams, + `Unsupported sampling image mime type: ${rawMimeType}` + ) + } + + const mimeType = normalizedMimeType ?? 'image/png' + const data = this.sanitizeSamplingImageData(content.data) + const dataUrl = `data:${mimeType};base64,${data}` + payload.messages.push({ + role: message.role, + type: 'image', + dataUrl, + mimeType + }) + payload.requiresVision = true + chatMessages.push({ + role: message.role, + content: [ + { + type: 'image_url', + image_url: { url: dataUrl, detail: 'auto' as const } + } + ] + }) + } else if (content.type === 'audio') { + throw new McpError( + ErrorCode.InvalidParams, + 'Audio sampling content is not supported by this client' + ) + } else { + throw new McpError( + ErrorCode.InvalidParams, + `Unsupported sampling content type: ${String((content as { type?: unknown }).type)}` + ) + } + } + + return { payload, chatMessages } + } + + private sanitizeSamplingImageData(rawData: unknown): string { + if (typeof rawData !== 'string') { + throw new McpError(ErrorCode.InvalidParams, 'Invalid sampling image payload received') + } + + const sanitized = rawData.replace(/\s+/g, '') + + if (!sanitized) { + throw new McpError(ErrorCode.InvalidParams, 'Invalid sampling image payload received') + } + + if (sanitized.length % 4 !== 0 || /[^A-Za-z0-9+/=]/.test(sanitized)) { + throw new McpError(ErrorCode.InvalidParams, 'Invalid sampling image payload received') + } + + let decoded: Buffer + + try { + decoded = Buffer.from(sanitized, 'base64') + } catch { + throw new McpError(ErrorCode.InvalidParams, 'Invalid sampling image payload received') + } + + if (!decoded.length) { + throw new McpError(ErrorCode.InvalidParams, 'Invalid sampling image payload received') + } + + const reencoded = decoded.toString('base64') + + if (reencoded.replace(/=+$/, '') !== sanitized.replace(/=+$/, '')) { + throw new McpError(ErrorCode.InvalidParams, 'Invalid sampling image payload received') + } + + return sanitized + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + private normalizeModelPreferences( + preferences: any + ): McpSamplingRequestPayload['modelPreferences'] { + if (!preferences || typeof preferences !== 'object') { + return undefined + } + + const normalized: McpSamplingRequestPayload['modelPreferences'] = {} + + if (typeof preferences.costPriority === 'number') { + normalized.costPriority = preferences.costPriority + } + if (typeof preferences.speedPriority === 'number') { + normalized.speedPriority = preferences.speedPriority + } + if (typeof preferences.intelligencePriority === 'number') { + normalized.intelligencePriority = preferences.intelligencePriority + } + if (Array.isArray(preferences.hints)) { + normalized.hints = preferences.hints.map((hint: { name?: unknown }) => ({ + name: typeof hint?.name === 'string' ? hint.name : undefined + })) + } + + if ( + normalized.costPriority === undefined && + normalized.speedPriority === undefined && + normalized.intelligencePriority === undefined && + (!normalized.hints || normalized.hints.length === 0) + ) { + return undefined + } + + return normalized + } + + private getServerLabel(): string | undefined { + const config = this.serverConfig + if (!config) { + return undefined + } + + const candidates: Array = [ + typeof config['descriptions'] === 'string' ? (config['descriptions'] as string) : undefined, + typeof config['description'] === 'string' ? (config['description'] as string) : undefined, + typeof config['name'] === 'string' ? (config['name'] as string) : undefined + ] + + return candidates.find((label) => label && label.trim().length > 0) + } + + private resolveModelDisplayName(providerId: string, modelId: string): string | undefined { + try { + const models = presenter.configPresenter.getProviderModels(providerId) || [] + const match = models.find((model) => model.id === modelId) + if (match?.name) { + return match.name + } + + const customModels = presenter.configPresenter.getCustomModels?.(providerId) || [] + const customMatch = customModels.find((model) => model.id === modelId) + if (customMatch?.name) { + return customMatch.name + } + } catch (error) { + console.warn( + `[MCP] Failed to resolve model display name for ${providerId}/${modelId}:`, + error + ) + } + + return undefined + } + // 检查服务器是否正在运行 isServerRunning(): boolean { return this.isConnected && !!this.client diff --git a/src/main/presenter/threadPresenter/index.ts b/src/main/presenter/threadPresenter/index.ts index 69a15f8b7..dbb03c5f1 100644 --- a/src/main/presenter/threadPresenter/index.ts +++ b/src/main/presenter/threadPresenter/index.ts @@ -436,6 +436,30 @@ export class ThreadPresenter implements IThreadPresenter { } this.finalizeLastBlock(state) + const permissionExtra: Record = { + needsUserAction: true + } + + if (permission_request?.permissionType) { + permissionExtra.permissionType = permission_request.permissionType + } + if (permission_request) { + permissionExtra.permissionRequest = JSON.stringify(permission_request) + if (permission_request.toolName) { + permissionExtra.toolName = permission_request.toolName + } + if (permission_request.serverName) { + permissionExtra.serverName = permission_request.serverName + } + } else { + if (tool_call_name) { + permissionExtra.toolName = tool_call_name + } + if (tool_call_server_name) { + permissionExtra.serverName = tool_call_server_name + } + } + state.message.content.push({ type: 'action', content: tool_call_response || '', @@ -450,7 +474,7 @@ export class ThreadPresenter implements IThreadPresenter { server_icons: tool_call_server_icons, server_description: tool_call_server_description }, - extra + extra: permissionExtra }) if (state) { diff --git a/src/renderer/floating/index.html b/src/renderer/floating/index.html index 0df14ab22..d9eb359f2 100644 --- a/src/renderer/floating/index.html +++ b/src/renderer/floating/index.html @@ -5,6 +5,10 @@ + Floating Button