From a83d9cbf9e39dc0dcaa4d04cfde067a3bf20626c Mon Sep 17 00:00:00 2001 From: duskzhen Date: Wed, 29 Oct 2025 15:00:52 +0800 Subject: [PATCH 01/13] feat(mcp): add sampling request handling --- src/main/events.ts | 5 +- src/main/presenter/mcpPresenter/index.ts | 62 +++- src/main/presenter/mcpPresenter/mcpClient.ts | 269 +++++++++++++++++- src/renderer/src/App.vue | 2 + .../src/components/mcp/McpSamplingDialog.vue | 201 +++++++++++++ src/renderer/src/events.ts | 5 +- src/renderer/src/i18n/en-US/mcp.json | 29 ++ src/renderer/src/stores/mcpSampling.ts | 162 +++++++++++ src/shared/types/core/mcp.ts | 45 +++ .../types/presenters/legacy.presenters.d.ts | 8 + 10 files changed, 782 insertions(+), 6 deletions(-) create mode 100644 src/renderer/src/components/mcp/McpSamplingDialog.vue create mode 100644 src/renderer/src/stores/mcpSampling.ts diff --git a/src/main/events.ts b/src/main/events.ts index b29775bfa..3acac95bd 100644 --- a/src/main/events.ts +++ b/src/main/events.ts @@ -111,7 +111,10 @@ export const MCP_EVENTS = { TOOL_CALL_RESULT: 'mcp:tool-call-result', SERVER_STATUS_CHANGED: 'mcp:server-status-changed', CLIENT_LIST_UPDATED: 'mcp:client-list-updated', - INITIALIZED: 'mcp:initialized' // 新增:MCP初始化完成事件 + INITIALIZED: 'mcp:initialized', // 新增:MCP初始化完成事件 + SAMPLING_REQUEST: 'mcp:sampling-request', + SAMPLING_DECISION: 'mcp:sampling-decision', + SAMPLING_CANCELLED: 'mcp:sampling-cancelled' } // 同步相关事件 diff --git a/src/main/presenter/mcpPresenter/index.ts b/src/main/presenter/mcpPresenter/index.ts index 733bc22f4..d56711713 100644 --- a/src/main/presenter/mcpPresenter/index.ts +++ b/src/main/presenter/mcpPresenter/index.ts @@ -8,7 +8,9 @@ import { Prompt, ResourceListEntry, Resource, - PromptListEntry + PromptListEntry, + McpSamplingRequestPayload, + McpSamplingDecision } from '@shared/presenter' import { ServerManager } from './serverManager' import { ToolManager } from './toolManager' @@ -85,6 +87,10 @@ export class McpPresenter implements IMCPPresenter { private isInitialized: boolean = false // McpRouter private mcprouter?: McpRouterManager + private pendingSamplingRequests = new Map< + string, + { resolve: (decision: McpSamplingDecision) => void; reject: (error: Error) => void } + >() constructor(configPresenter?: IConfigPresenter) { console.log('Initializing MCP Presenter') @@ -580,6 +586,60 @@ export class McpPresenter implements IMCPPresenter { return { content: formattedContent, rawData: toolCallResult } } + async handleSamplingRequest(request: McpSamplingRequestPayload): Promise { + if (!request || !request.requestId) { + throw new Error('Invalid sampling request: missing requestId') + } + + return new Promise((resolve, reject) => { + try { + this.pendingSamplingRequests.set(request.requestId, { resolve, reject }) + eventBus.sendToRenderer(MCP_EVENTS.SAMPLING_REQUEST, SendTarget.DEFAULT_TAB, request) + } catch (error) { + this.pendingSamplingRequests.delete(request.requestId) + reject(error instanceof Error ? error : new Error(String(error))) + } + }) + } + + async submitSamplingDecision(decision: McpSamplingDecision): Promise { + if (!decision || !decision.requestId) { + throw new Error('Invalid sampling decision: missing requestId') + } + + const pending = this.pendingSamplingRequests.get(decision.requestId) + if (!pending) { + console.warn( + `[MCP] Sampling request ${decision.requestId} not found when submitting decision` + ) + return + } + + this.pendingSamplingRequests.delete(decision.requestId) + pending.resolve(decision) + + eventBus.sendToRenderer(MCP_EVENTS.SAMPLING_DECISION, SendTarget.ALL_WINDOWS, decision) + } + + async cancelSamplingRequest(requestId: string, reason?: string): Promise { + if (!requestId) { + return + } + + const pending = this.pendingSamplingRequests.get(requestId) + if (!pending) { + return + } + + this.pendingSamplingRequests.delete(requestId) + pending.reject(new Error(reason ?? 'Sampling request cancelled')) + + eventBus.sendToRenderer(MCP_EVENTS.SAMPLING_CANCELLED, SendTarget.ALL_WINDOWS, { + requestId, + reason: reason ?? 'cancelled' + }) + } + // Convert MCPToolDefinition to MCPTool private mcpToolDefinitionToMcpTool( toolDefinition: MCPToolDefinition, diff --git a/src/main/presenter/mcpPresenter/mcpClient.ts b/src/main/presenter/mcpPresenter/mcpClient.ts index 647c4941c..98d667cf5 100644 --- a/src/main/presenter/mcpPresenter/mcpClient.ts +++ b/src/main/presenter/mcpPresenter/mcpClient.ts @@ -8,8 +8,12 @@ import { PromptListChangedNotificationSchema, ResourceListChangedNotificationSchema, ResourceUpdatedNotificationSchema, - LoggingMessageNotificationSchema + LoggingMessageNotificationSchema, + CreateMessageRequestSchema, + ErrorCode, + McpError } from '@modelcontextprotocol/sdk/types.js' +import type { CreateMessageRequest, CreateMessageResult } from '@modelcontextprotocol/sdk/types.js' import { eventBus, SendTarget } from '@/eventbus' import { MCP_EVENTS } from '@/events' import path from 'path' @@ -25,7 +29,10 @@ import { Tool, Prompt, ResourceListEntry, - Resource + Resource, + ChatMessage, + McpSamplingRequestPayload, + McpSamplingDecision } from '@shared/presenter' // TODO: resources 和 prompts 的类型,Notifactions 的类型 https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/examples/client/simpleStreamableHttp.ts @@ -58,6 +65,12 @@ interface SessionError extends Error { isSessionExpired?: boolean } +interface RequestHandlerContext { + signal?: AbortSignal + requestId?: string | number + [key: string]: unknown +} + // Helper function to check if error is session-related function isSessionError(error: unknown): error is SessionError { if (error instanceof Error) { @@ -590,7 +603,8 @@ export class McpClient { capabilities: { resources: {}, tools: {}, - prompts: {} + prompts: {}, + sampling: {} } } ) @@ -598,6 +612,11 @@ export class McpClient { // 设置通知处理器 this.registerNotificationHandlers() + // 注册采样请求处理器 + this.client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { + return this.handleSamplingCreateMessage(request, extra) + }) + // 设置连接超时 const timeoutPromise = new Promise((_, reject) => { this.connectionTimeout = setTimeout( @@ -764,6 +783,250 @@ export class McpClient { }) } + private async handleSamplingCreateMessage( + request: CreateMessageRequest, + extra: RequestHandlerContext + ): Promise { + const params = request.params ?? {} + const requestId = this.resolveSamplingRequestId(extra) + const { payload, chatMessages } = this.prepareSamplingContext(requestId, params) + + const decisionPromise = presenter.mcpPresenter.handleSamplingRequest(payload) + const signal = extra?.signal as AbortSignal | undefined + + let decision: McpSamplingDecision + if (signal) { + decision = await new Promise((resolve, reject) => { + const onAbort = () => { + signal.removeEventListener('abort', onAbort) + void presenter.mcpPresenter + .cancelSamplingRequest(payload.requestId, 'cancelled by server') + .catch((error) => { + console.warn(`[MCP] Failed to cancel sampling request ${payload.requestId}:`, error) + }) + reject(new McpError(ErrorCode.RequestTimeout, 'Sampling request cancelled')) + } + + if (signal.aborted) { + onAbort() + return + } + + signal.addEventListener('abort', onAbort, { once: true }) + decisionPromise + .then((value) => { + signal.removeEventListener('abort', onAbort) + resolve(value) + }) + .catch((error) => { + signal.removeEventListener('abort', onAbort) + reject(error) + }) + }) + } else { + decision = await decisionPromise + } + + if (!decision.approved) { + throw new McpError(ErrorCode.InvalidRequest, 'User rejected sampling request') + } + + if (!decision.providerId || !decision.modelId) { + throw new McpError(ErrorCode.InvalidParams, 'No model selected for sampling request') + } + + let assistantText = '' + try { + assistantText = await presenter.llmproviderPresenter.generateCompletionStandalone( + decision.providerId, + chatMessages, + decision.modelId, + undefined, + params.maxTokens + ) + } catch (error) { + console.error(`[MCP] Sampling request failed for server ${this.serverName}:`, error) + throw new McpError( + ErrorCode.InternalError, + error instanceof Error ? error.message : 'Sampling request failed' + ) + } + + const modelName = + this.resolveModelDisplayName(decision.providerId, decision.modelId) ?? decision.modelId + + const result: CreateMessageResult = { + role: 'assistant', + model: modelName, + stopReason: 'endTurn', + content: { + type: 'text', + text: assistantText ?? '' + } + } + + return result + } + + private resolveSamplingRequestId(extra: RequestHandlerContext): string { + const rawId = extra?.requestId + if (typeof rawId === 'string' || typeof rawId === 'number') { + return String(rawId) + } + + return `${this.serverName}-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 8)}` + } + + private prepareSamplingContext( + requestId: string, + params: CreateMessageRequest['params'] + ): { payload: McpSamplingRequestPayload; chatMessages: ChatMessage[] } { + const payload: McpSamplingRequestPayload = { + requestId, + serverName: this.serverName, + serverLabel: this.getServerLabel(), + systemPrompt: typeof params?.systemPrompt === 'string' ? params.systemPrompt : undefined, + maxTokens: typeof params?.maxTokens === 'number' ? params.maxTokens : undefined, + modelPreferences: this.normalizeModelPreferences(params?.modelPreferences), + requiresVision: false, + messages: [] + } + + const chatMessages: ChatMessage[] = [] + + if (payload.systemPrompt) { + chatMessages.push({ role: 'system', content: payload.systemPrompt }) + } + + const messageList = Array.isArray(params?.messages) ? params.messages : [] + + for (const message of messageList) { + if (!message || (message.role !== 'user' && message.role !== 'assistant')) { + continue + } + + const rawContent = message.content + if (!rawContent || typeof rawContent !== 'object' || !('type' in rawContent)) { + throw new McpError(ErrorCode.InvalidParams, 'Invalid sampling message content received') + } + + const content = rawContent as { type: string } & Record + + if (content.type === 'text') { + const text = typeof content.text === 'string' ? content.text : '' + payload.messages.push({ role: message.role, type: 'text', text }) + chatMessages.push({ role: message.role, content: text }) + } else if (content.type === 'image') { + const mimeType = typeof content.mimeType === 'string' ? content.mimeType : 'image/png' + const data = typeof content.data === 'string' ? content.data : '' + const dataUrl = `data:${mimeType};base64,${data}` + payload.messages.push({ + role: message.role, + type: 'image', + dataUrl, + mimeType + }) + payload.requiresVision = true + chatMessages.push({ + role: message.role, + content: [ + { + type: 'image_url', + image_url: { url: dataUrl, detail: 'auto' as const } + } + ] + }) + } else if (content.type === 'audio') { + throw new McpError( + ErrorCode.InvalidParams, + 'Audio sampling content is not supported by this client' + ) + } else { + throw new McpError( + ErrorCode.InvalidParams, + `Unsupported sampling content type: ${String((content as { type?: unknown }).type)}` + ) + } + } + + return { payload, chatMessages } + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + private normalizeModelPreferences( + preferences: any + ): McpSamplingRequestPayload['modelPreferences'] { + if (!preferences || typeof preferences !== 'object') { + return undefined + } + + const normalized: McpSamplingRequestPayload['modelPreferences'] = {} + + if (typeof preferences.costPriority === 'number') { + normalized.costPriority = preferences.costPriority + } + if (typeof preferences.speedPriority === 'number') { + normalized.speedPriority = preferences.speedPriority + } + if (typeof preferences.intelligencePriority === 'number') { + normalized.intelligencePriority = preferences.intelligencePriority + } + if (Array.isArray(preferences.hints)) { + normalized.hints = preferences.hints.map((hint: { name?: unknown }) => ({ + name: typeof hint?.name === 'string' ? hint.name : undefined + })) + } + + if ( + normalized.costPriority === undefined && + normalized.speedPriority === undefined && + normalized.intelligencePriority === undefined && + (!normalized.hints || normalized.hints.length === 0) + ) { + return undefined + } + + return normalized + } + + private getServerLabel(): string | undefined { + const config = this.serverConfig + if (!config) { + return undefined + } + + const candidates: Array = [ + typeof config['descriptions'] === 'string' ? (config['descriptions'] as string) : undefined, + typeof config['description'] === 'string' ? (config['description'] as string) : undefined, + typeof config['name'] === 'string' ? (config['name'] as string) : undefined + ] + + return candidates.find((label) => label && label.trim().length > 0) + } + + private resolveModelDisplayName(providerId: string, modelId: string): string | undefined { + try { + const models = presenter.configPresenter.getProviderModels(providerId) || [] + const match = models.find((model) => model.id === modelId) + if (match?.name) { + return match.name + } + + const customModels = presenter.configPresenter.getCustomModels?.(providerId) || [] + const customMatch = customModels.find((model) => model.id === modelId) + if (customMatch?.name) { + return customMatch.name + } + } catch (error) { + console.warn( + `[MCP] Failed to resolve model display name for ${providerId}/${modelId}:`, + error + ) + } + + return undefined + } + // 检查服务器是否正在运行 isServerRunning(): boolean { return this.isConnected && !!this.client diff --git a/src/renderer/src/App.vue b/src/renderer/src/App.vue index 85948a14d..da2e292fc 100644 --- a/src/renderer/src/App.vue +++ b/src/renderer/src/App.vue @@ -18,6 +18,7 @@ import ThreadView from '@/components/ThreadView.vue' import ModelCheckDialog from '@/components/settings/ModelCheckDialog.vue' import { useModelCheckStore } from '@/stores/modelCheck' import MessageDialog from './components/ui/MessageDialog.vue' +import McpSamplingDialog from '@/components/mcp/McpSamplingDialog.vue' import 'vue-sonner/style.css' // vue-sonner v2 requires this import const route = useRoute() @@ -338,6 +339,7 @@ onBeforeUnmount(() => { + diff --git a/src/renderer/src/components/mcp/McpSamplingDialog.vue b/src/renderer/src/components/mcp/McpSamplingDialog.vue new file mode 100644 index 000000000..b813c33ff --- /dev/null +++ b/src/renderer/src/components/mcp/McpSamplingDialog.vue @@ -0,0 +1,201 @@ + + + diff --git a/src/renderer/src/events.ts b/src/renderer/src/events.ts index 6553306cb..f240fae04 100644 --- a/src/renderer/src/events.ts +++ b/src/renderer/src/events.ts @@ -70,7 +70,10 @@ export const MCP_EVENTS = { SERVER_STOPPED: 'mcp:server-stopped', CONFIG_CHANGED: 'mcp:config-changed', TOOL_CALL_RESULT: 'mcp:tool-call-result', - SERVER_STATUS_CHANGED: 'mcp:server-status-changed' + SERVER_STATUS_CHANGED: 'mcp:server-status-changed', + SAMPLING_REQUEST: 'mcp:sampling-request', + SAMPLING_DECISION: 'mcp:sampling-decision', + SAMPLING_CANCELLED: 'mcp:sampling-cancelled' } // 新增会议相关事件 diff --git a/src/renderer/src/i18n/en-US/mcp.json b/src/renderer/src/i18n/en-US/mcp.json index 34bf142fe..1431332f4 100644 --- a/src/renderer/src/i18n/en-US/mcp.json +++ b/src/renderer/src/i18n/en-US/mcp.json @@ -136,6 +136,35 @@ "startServer": "Start the server", "stopServer": "Stop the server", "stopped": "Stopped", + "sampling": { + "title": "Sampling request from {server}", + "unknownServer": "Unknown server", + "description": "Review the context shared by the MCP server and choose whether to generate a response.", + "systemPrompt": "System prompt", + "messagesTitle": "Conversation context", + "preferencesTitle": "Model preferences", + "approve": "Approve", + "reject": "Reject", + "confirm": "Send response", + "confirming": "Sending…", + "visionWarning": "The selected model does not support vision input. Please pick a vision-capable model before continuing.", + "selectedModelLabel": "Responding with {model} ({provider})", + "unsupportedMessage": "This content type is not supported.", + "imageAlt": "Image {index}", + "unknownMime": "Unknown MIME type", + "unknownHint": "Unnamed hint", + "contentType": { + "text": "Text", + "image": "Image", + "audio": "Audio" + }, + "preference": { + "cost": "Cost priority", + "speed": "Speed priority", + "intelligence": "Intelligence priority", + "hints": "Model hints" + } + }, "tabs": { "servers": "Servers", "tools": "Tools" diff --git a/src/renderer/src/stores/mcpSampling.ts b/src/renderer/src/stores/mcpSampling.ts new file mode 100644 index 000000000..5d0ab6744 --- /dev/null +++ b/src/renderer/src/stores/mcpSampling.ts @@ -0,0 +1,162 @@ +import { defineStore } from 'pinia' +import { ref, computed, onMounted, onUnmounted } from 'vue' +import { usePresenter } from '@/composables/usePresenter' +import { MCP_EVENTS } from '@/events' +import type { + McpSamplingDecision, + McpSamplingRequestPayload, + RENDERER_MODEL_META +} from '@shared/presenter' +import { useChatStore } from '@/stores/chat' +import { useSettingsStore } from '@/stores/settings' + +export const useMcpSamplingStore = defineStore('mcpSampling', () => { + const mcpPresenter = usePresenter('mcpPresenter') + const chatStore = useChatStore() + const settingsStore = useSettingsStore() + + const request = ref(null) + const isOpen = ref(false) + const isSubmitting = ref(false) + const isChoosingModel = ref(false) + const selectedProviderId = ref(null) + const selectedModel = ref(null) + + const requiresVision = computed(() => request.value?.requiresVision ?? false) + const selectedModelSupportsVision = computed(() => selectedModel.value?.vision ?? false) + + const resetSelection = () => { + selectedProviderId.value = chatStore.chatConfig.providerId || null + const providerId = selectedProviderId.value + if (!providerId) { + selectedModel.value = null + return + } + + const providerEntry = settingsStore.enabledModels.find( + (entry) => entry.providerId === providerId + ) + const activeModelId = chatStore.chatConfig.modelId + const activeModel = providerEntry?.models.find((model) => model.id === activeModelId) + + if (activeModel) { + selectedModel.value = activeModel + return + } + + selectedModel.value = providerEntry?.models?.[0] ?? null + } + + const openRequest = (payload: McpSamplingRequestPayload) => { + request.value = payload + isOpen.value = true + isChoosingModel.value = false + isSubmitting.value = false + resetSelection() + } + + const closeRequest = () => { + isOpen.value = false + isChoosingModel.value = false + isSubmitting.value = false + request.value = null + selectedProviderId.value = null + selectedModel.value = null + } + + const beginApprove = () => { + isChoosingModel.value = true + } + + const selectModel = (model: RENDERER_MODEL_META, providerId: string) => { + selectedModel.value = model + selectedProviderId.value = providerId + } + + const submitDecision = async (decision: McpSamplingDecision) => { + if (!request.value) { + return + } + + isSubmitting.value = true + try { + await mcpPresenter.submitSamplingDecision(decision) + closeRequest() + } catch (error) { + console.error('[MCP Sampling] Failed to submit decision:', error) + isSubmitting.value = false + } + } + + const confirmApproval = async () => { + if (!request.value || !selectedProviderId.value || !selectedModel.value) { + return + } + + await submitDecision({ + requestId: request.value.requestId, + approved: true, + providerId: selectedProviderId.value, + modelId: selectedModel.value.id + }) + } + + const rejectRequest = async () => { + if (!request.value) { + return + } + + await submitDecision({ + requestId: request.value.requestId, + approved: false, + reason: 'User rejected sampling request' + }) + } + + const handleSamplingRequest = (_event: unknown, payload: McpSamplingRequestPayload) => { + openRequest(payload) + } + + const handleSamplingCancelled = (_event: unknown, payload: { requestId: string }) => { + if (request.value && payload.requestId === request.value.requestId) { + closeRequest() + } + } + + const handleSamplingDecision = (_event: unknown, payload: McpSamplingDecision) => { + if (request.value && payload.requestId === request.value.requestId) { + closeRequest() + } + } + + onMounted(() => { + window.electron.ipcRenderer.on(MCP_EVENTS.SAMPLING_REQUEST, handleSamplingRequest) + window.electron.ipcRenderer.on(MCP_EVENTS.SAMPLING_CANCELLED, handleSamplingCancelled) + window.electron.ipcRenderer.on(MCP_EVENTS.SAMPLING_DECISION, handleSamplingDecision) + }) + + onUnmounted(() => { + window.electron.ipcRenderer.removeListener(MCP_EVENTS.SAMPLING_REQUEST, handleSamplingRequest) + window.electron.ipcRenderer.removeListener( + MCP_EVENTS.SAMPLING_CANCELLED, + handleSamplingCancelled + ) + window.electron.ipcRenderer.removeListener(MCP_EVENTS.SAMPLING_DECISION, handleSamplingDecision) + }) + + return { + request, + isOpen, + isSubmitting, + isChoosingModel, + requiresVision, + selectedModelSupportsVision, + selectedProviderId, + selectedModel, + beginApprove, + selectModel, + confirmApproval, + rejectRequest, + closeRequest + } +}) diff --git a/src/shared/types/core/mcp.ts b/src/shared/types/core/mcp.ts index 6175e150a..7f35e01e2 100644 --- a/src/shared/types/core/mcp.ts +++ b/src/shared/types/core/mcp.ts @@ -69,3 +69,48 @@ export interface MCPToolResponse { description: string } } + +export type McpSamplingMessageType = 'text' | 'image' | 'audio' + +export interface McpSamplingMessage { + role: 'user' | 'assistant' + type: McpSamplingMessageType + /** + * Plain text content when the message type is `text`. + */ + text?: string + /** + * Base64 payload rendered as a data URL in the renderer when type is `image` or `audio`. + */ + dataUrl?: string + /** + * MIME type of the binary payload when available. + */ + mimeType?: string +} + +export interface McpSamplingModelPreferences { + costPriority?: number + speedPriority?: number + intelligencePriority?: number + hints?: Array<{ name?: string | null }> +} + +export interface McpSamplingRequestPayload { + requestId: string + serverName: string + serverLabel?: string + systemPrompt?: string + maxTokens?: number + modelPreferences?: McpSamplingModelPreferences + requiresVision: boolean + messages: McpSamplingMessage[] +} + +export interface McpSamplingDecision { + requestId: string + approved: boolean + providerId?: string + modelId?: string + reason?: string +} diff --git a/src/shared/types/presenters/legacy.presenters.d.ts b/src/shared/types/presenters/legacy.presenters.d.ts index e9410be96..7f378e8bf 100644 --- a/src/shared/types/presenters/legacy.presenters.d.ts +++ b/src/shared/types/presenters/legacy.presenters.d.ts @@ -1206,6 +1206,11 @@ export interface MCPToolResponse { } } +export type McpSamplingMessage = import('../../core/mcp').McpSamplingMessage +export type McpSamplingRequestPayload = import('../../core/mcp').McpSamplingRequestPayload +export type McpSamplingDecision = import('../../core/mcp').McpSamplingDecision +export type McpSamplingModelPreferences = import('../../core/mcp').McpSamplingModelPreferences + /** Content item type */ export type MCPContentItem = MCPTextContent | MCPImageContent | MCPResourceContent @@ -1255,6 +1260,9 @@ export interface IMCPPresenter { getPrompt(prompt: PromptListEntry, args?: Record): Promise readResource(resource: ResourceListEntry): Promise callTool(request: MCPToolCall): Promise<{ content: string; rawData: MCPToolResponse }> + handleSamplingRequest(request: McpSamplingRequestPayload): Promise + submitSamplingDecision(decision: McpSamplingDecision): Promise + cancelSamplingRequest(requestId: string, reason?: string): Promise setMcpEnabled(enabled: boolean): Promise getMcpEnabled(): Promise resetToDefaultServers(): Promise From 2a527293b0520c17889f3d196b46fed7c650f554 Mon Sep 17 00:00:00 2001 From: duskzhen Date: Wed, 29 Oct 2025 15:10:14 +0800 Subject: [PATCH 02/13] test(mcp): cover sampling flows --- test/main/presenter/mcpClient.test.ts | 176 +++++++++++++++++++++++++- 1 file changed, 175 insertions(+), 1 deletion(-) diff --git a/test/main/presenter/mcpClient.test.ts b/test/main/presenter/mcpClient.test.ts index e53525d12..e1d6269d5 100644 --- a/test/main/presenter/mcpClient.test.ts +++ b/test/main/presenter/mcpClient.test.ts @@ -2,6 +2,7 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest' import { McpClient } from '../../../src/main/presenter/mcpPresenter/mcpClient' import path from 'path' import fs from 'fs' +import { ErrorCode } from '@modelcontextprotocol/sdk/types.js' // Mock electron modules vi.mock('electron', () => ({ @@ -33,14 +34,37 @@ vi.mock('../../../src/main/eventbus', () => ({ })) // Mock presenter +const presenterMocks = vi.hoisted(() => ({ + handleSamplingRequest: vi.fn(), + cancelSamplingRequest: vi.fn(), + generateCompletionStandalone: vi.fn(), + getProviderModels: vi.fn(), + getCustomModels: vi.fn() +})) + vi.mock('../../../src/main/presenter', () => ({ presenter: { configPresenter: { - getMcpServers: vi.fn() + getMcpServers: vi.fn(), + getProviderModels: presenterMocks.getProviderModels, + getCustomModels: presenterMocks.getCustomModels + }, + mcpPresenter: { + handleSamplingRequest: presenterMocks.handleSamplingRequest, + cancelSamplingRequest: presenterMocks.cancelSamplingRequest + }, + llmproviderPresenter: { + generateCompletionStandalone: presenterMocks.generateCompletionStandalone } } })) +const mockHandleSamplingRequest = presenterMocks.handleSamplingRequest +const mockCancelSamplingRequest = presenterMocks.cancelSamplingRequest +const mockGenerateCompletionStandalone = presenterMocks.generateCompletionStandalone +const mockGetProviderModels = presenterMocks.getProviderModels +const mockGetCustomModels = presenterMocks.getCustomModels + // Mock other dependencies that might be imported by mcpClient vi.mock('../../../src/main/events', () => ({ MCP_EVENTS: { @@ -90,6 +114,12 @@ describe('McpClient Runtime Command Processing Tests', () => { mockFsExistsSync = vi.mocked(fs.existsSync) vi.clearAllMocks() + mockHandleSamplingRequest.mockReset() + mockCancelSamplingRequest.mockReset() + mockGenerateCompletionStandalone.mockReset() + mockGetProviderModels.mockReset() + mockGetCustomModels.mockReset() + // Mock runtime paths to exist mockFsExistsSync.mockImplementation((filePath: string | Buffer | URL) => { const pathStr = String(filePath) @@ -390,4 +420,148 @@ describe('McpClient Runtime Command Processing Tests', () => { delete process.env.TEST_PATH }) }) + + describe('Sampling support', () => { + it('should prepare sampling payload and chat messages from request params', () => { + const client = new McpClient('server-one', { + type: 'stdio', + description: 'Sample server' + }) + + const params = { + systemPrompt: 'You are a helpful assistant.', + maxTokens: 128, + modelPreferences: { + costPriority: 0.5, + hints: [{ name: 'fast' }, { name: null }] + }, + messages: [ + { role: 'user', content: { type: 'text', text: 'hello' } }, + { + role: 'assistant', + content: { type: 'image', mimeType: 'image/jpeg', data: 'aGVsbG8=' } + } + ] + } + + const { payload, chatMessages } = (client as any).prepareSamplingContext('req-123', params) + + expect(payload).toEqual({ + requestId: 'req-123', + serverName: 'server-one', + serverLabel: 'Sample server', + systemPrompt: 'You are a helpful assistant.', + maxTokens: 128, + modelPreferences: { + costPriority: 0.5, + hints: [{ name: 'fast' }, { name: undefined }] + }, + requiresVision: true, + messages: [ + { role: 'user', type: 'text', text: 'hello' }, + { + role: 'assistant', + type: 'image', + dataUrl: 'data:image/jpeg;base64,aGVsbG8=', + mimeType: 'image/jpeg' + } + ] + }) + + expect(chatMessages).toEqual([ + { role: 'system', content: 'You are a helpful assistant.' }, + { role: 'user', content: 'hello' }, + { + role: 'assistant', + content: [ + { + type: 'image_url', + image_url: { url: 'data:image/jpeg;base64,aGVsbG8=', detail: 'auto' } + } + ] + } + ]) + }) + + it('should return assistant response when sampling decision is approved', async () => { + const client = new McpClient('code-reviewer', { + type: 'stdio', + description: 'Code Reviewer Server' + }) + + mockHandleSamplingRequest.mockResolvedValue({ + requestId: 'rpc-001', + approved: true, + providerId: 'provider-1', + modelId: 'model-42' + }) + mockGenerateCompletionStandalone.mockResolvedValue('Generated response') + mockGetProviderModels.mockReturnValue([{ id: 'model-42', name: 'Model Forty Two' }]) + mockGetCustomModels.mockReturnValue([]) + + const request = { + params: { + maxTokens: 256, + systemPrompt: 'System context', + messages: [{ role: 'user', content: { type: 'text', text: 'Explain this change.' } }] + } + } + + const result = await (client as any).handleSamplingCreateMessage(request, { + requestId: 'rpc-001' + }) + + expect(mockHandleSamplingRequest).toHaveBeenCalledWith( + expect.objectContaining({ + requestId: 'rpc-001', + serverName: 'code-reviewer' + }) + ) + expect(mockGenerateCompletionStandalone).toHaveBeenCalledWith( + 'provider-1', + [ + { role: 'system', content: 'System context' }, + { role: 'user', content: 'Explain this change.' } + ], + 'model-42', + undefined, + 256 + ) + + expect(result).toEqual({ + role: 'assistant', + model: 'Model Forty Two', + stopReason: 'endTurn', + content: { type: 'text', text: 'Generated response' } + }) + }) + + it('should throw when sampling decision is rejected by the user', async () => { + const client = new McpClient('code-reviewer', { type: 'stdio' }) + + mockHandleSamplingRequest.mockResolvedValue({ + requestId: 'rpc-002', + approved: false + }) + + const request = { + params: { + messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }] + } + } + + let caughtError: unknown + try { + await (client as any).handleSamplingCreateMessage(request, { requestId: 'rpc-002' }) + } catch (error) { + caughtError = error + } + + expect(caughtError).toBeInstanceOf(Error) + expect((caughtError as Error).message).toContain('User rejected sampling request') + expect(caughtError).toHaveProperty('code', ErrorCode.InvalidRequest) + + expect(mockGenerateCompletionStandalone).not.toHaveBeenCalled() + }) + }) }) From d3df6275f41061d5ddf1ea201d84e49e1e7e14ac Mon Sep 17 00:00:00 2001 From: duskzhen Date: Wed, 29 Oct 2025 15:35:48 +0800 Subject: [PATCH 03/13] fix(mcp): polish sampling ui and permissions --- src/main/presenter/threadPresenter/index.ts | 28 ++++++++++++++++++- src/renderer/src/components/ModelChooser.vue | 5 +--- .../src/components/mcp/McpSamplingDialog.vue | 9 ++++-- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/main/presenter/threadPresenter/index.ts b/src/main/presenter/threadPresenter/index.ts index 451d0088c..7eeb0af77 100644 --- a/src/main/presenter/threadPresenter/index.ts +++ b/src/main/presenter/threadPresenter/index.ts @@ -187,6 +187,7 @@ export class ThreadPresenter implements IThreadPresenter { tool_call_server_description, tool_call_response_raw, tool_call, + permission_request, totalUsage, image_data } = msg @@ -388,6 +389,30 @@ export class ThreadPresenter implements IThreadPresenter { } this.finalizeLastBlock(state) + const permissionExtra: Record = { + needsUserAction: true + } + + if (permission_request?.permissionType) { + permissionExtra.permissionType = permission_request.permissionType + } + if (permission_request) { + permissionExtra.permissionRequest = JSON.stringify(permission_request) + if (permission_request.toolName) { + permissionExtra.toolName = permission_request.toolName + } + if (permission_request.serverName) { + permissionExtra.serverName = permission_request.serverName + } + } else { + if (tool_call_name) { + permissionExtra.toolName = tool_call_name + } + if (tool_call_server_name) { + permissionExtra.serverName = tool_call_server_name + } + } + state.message.content.push({ type: 'action', content: tool_call_response || '', @@ -401,7 +426,8 @@ export class ThreadPresenter implements IThreadPresenter { server_name: tool_call_server_name, server_icons: tool_call_server_icons, server_description: tool_call_server_description - } + }, + extra: permissionExtra }) this.searchingMessages.add(eventId) diff --git a/src/renderer/src/components/ModelChooser.vue b/src/renderer/src/components/ModelChooser.vue index 6a07c399f..989ac16f6 100644 --- a/src/renderer/src/components/ModelChooser.vue +++ b/src/renderer/src/components/ModelChooser.vue @@ -1,8 +1,5 @@