From 0ab2544f57a3fdcd5880a768c6ac2b7ecf67e476 Mon Sep 17 00:00:00 2001 From: Jasmeet Bhatia Date: Sun, 8 Feb 2026 18:31:28 +0000 Subject: [PATCH 1/2] feat(mcp): add MCP tool progress notifications Display real-time progress bars for MCP tool operations by wiring SDK onprogress callbacks through core events to the CLI UI layer. Closes #16934 --- .../components/messages/ToolMessage.test.tsx | 14 +- .../ui/components/messages/ToolMessage.tsx | 15 +- .../components/messages/ToolShared.test.tsx | 58 +++++ .../src/ui/components/messages/ToolShared.tsx | 50 ++++ .../__snapshots__/ToolMessage.test.tsx.snap | 10 + .../__snapshots__/ToolShared.test.tsx.snap | 32 +++ packages/cli/src/ui/hooks/toolMapping.ts | 4 + .../cli/src/ui/hooks/useMCPProgress.test.ts | 242 ++++++++++++++++++ packages/cli/src/ui/hooks/useMCPProgress.ts | 68 +++++ packages/cli/src/ui/hooks/useToolScheduler.ts | 29 ++- packages/cli/src/ui/types.ts | 2 + .../core/src/scheduler/tool-executor.test.ts | 70 +++++ packages/core/src/scheduler/tool-executor.ts | 2 + packages/core/src/scheduler/types.ts | 2 + packages/core/src/tools/mcp-client.test.ts | 79 ++++++ packages/core/src/tools/mcp-client.ts | 25 +- packages/core/src/tools/mcp-tool.test.ts | 127 ++++++++- packages/core/src/tools/mcp-tool.ts | 37 ++- packages/core/src/tools/tools.ts | 5 + packages/core/src/utils/events.ts | 36 +++ 20 files changed, 895 insertions(+), 12 deletions(-) create mode 100644 packages/cli/src/ui/components/messages/ToolShared.test.tsx create mode 100644 packages/cli/src/ui/components/messages/__snapshots__/ToolShared.test.tsx.snap create mode 100644 packages/cli/src/ui/hooks/useMCPProgress.test.ts create mode 100644 packages/cli/src/ui/hooks/useMCPProgress.ts diff --git a/packages/cli/src/ui/components/messages/ToolMessage.test.tsx b/packages/cli/src/ui/components/messages/ToolMessage.test.tsx index 29012bbd26f..4a96131e026 100644 --- a/packages/cli/src/ui/components/messages/ToolMessage.test.tsx +++ b/packages/cli/src/ui/components/messages/ToolMessage.test.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { ToolMessage, type ToolMessageProps } from './ToolMessage.js'; -import { describe, it, expect, vi } from 'vitest'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; import { StreamingState } from '../../types.js'; import { Text } from 'ink'; import { type AnsiOutput, CoreToolCallStatus } from '@google/gemini-cli-core'; @@ -299,6 +299,18 @@ describe('', () => { expect(lowEmphasisFrame()).toMatchSnapshot(); }); + it('renders MCPProgressIndicator when executing with progress', () => { + const { lastFrame } = renderWithContext( + , + StreamingState.Responding, + ); + expect(lastFrame()).toMatchSnapshot(); + }); + it('renders AnsiOutputText for AnsiOutput results', () => { const ansiResult: AnsiOutput = [ [ diff --git a/packages/cli/src/ui/components/messages/ToolMessage.tsx b/packages/cli/src/ui/components/messages/ToolMessage.tsx index 06ad6b3f7b4..6a067eada6b 100644 --- a/packages/cli/src/ui/components/messages/ToolMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolMessage.tsx @@ -13,6 +13,7 @@ import { ToolStatusIndicator, ToolInfo, TrailingIndicator, + MCPProgressIndicator, type TextEmphasis, STATUS_INDICATOR_WIDTH, isThisShellFocusable as checkIsShellFocusable, @@ -20,7 +21,7 @@ import { useFocusHint, FocusHint, } from './ToolShared.js'; -import { type Config } from '@google/gemini-cli-core'; +import { type Config, CoreToolCallStatus } from '@google/gemini-cli-core'; import { ShellInputPrompt } from '../ShellInputPrompt.js'; export type { TextEmphasis }; @@ -55,6 +56,7 @@ export const ToolMessage: React.FC = ({ embeddedShellFocused, ptyId, config, + mcpProgress, }) => { const isThisShellFocused = checkIsShellFocused( name, @@ -108,6 +110,17 @@ export const ToolMessage: React.FC = ({ paddingX={1} flexDirection="column" > + {status === CoreToolCallStatus.Executing && mcpProgress && ( + + )} { + it('renders determinate progress bar at 50%', () => { + const { lastFrame } = render( + , + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('renders fully complete progress bar', () => { + const { lastFrame } = render( + , + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('renders indeterminate progress without total', () => { + const { lastFrame } = render( + , + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('renders progress message when provided', () => { + const { lastFrame } = render( + , + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('scales bar width correctly', () => { + const { lastFrame } = render( + , + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('clamps progress exceeding total', () => { + const { lastFrame } = render( + , + ); + expect(lastFrame()).toMatchSnapshot(); + }); +}); diff --git a/packages/cli/src/ui/components/messages/ToolShared.tsx b/packages/cli/src/ui/components/messages/ToolShared.tsx index fc0e546cc99..4cf9de92701 100644 --- a/packages/cli/src/ui/components/messages/ToolShared.tsx +++ b/packages/cli/src/ui/components/messages/ToolShared.tsx @@ -233,3 +233,53 @@ export const TrailingIndicator: React.FC = () => ( ← ); + +export interface MCPProgressIndicatorProps { + progress: number; + total?: number; + message?: string; + barWidth: number; +} + +/** + * Values are clamped to prevent crashes from negative repeat counts + * when progress > total (which can happen with misbehaving MCP servers). + */ +export const MCPProgressIndicator: React.FC = ({ + progress, + total, + message, + barWidth, +}) => { + const percentage = + total && total > 0 ? Math.round((progress / total) * 100) : null; + + let rawFilled: number; + if (total && total > 0) { + rawFilled = Math.round((progress / total) * barWidth); + } else { + rawFilled = Math.floor(progress) % (barWidth + 1); + } + + const filled = Math.max( + 0, + Math.min(Number.isFinite(rawFilled) ? rawFilled : 0, barWidth), + ); + const empty = Math.max(0, barWidth - filled); + const progressBar = '\u2588'.repeat(filled) + '\u2591'.repeat(empty); + + return ( + + + + {progressBar} {percentage !== null ? `${percentage}%` : `${progress}`} + + + {message && ( + + {message} + + )} + + ); +}; diff --git a/packages/cli/src/ui/components/messages/__snapshots__/ToolMessage.test.tsx.snap b/packages/cli/src/ui/components/messages/__snapshots__/ToolMessage.test.tsx.snap index a3fedd751b1..a24d7341ba4 100644 --- a/packages/cli/src/ui/components/messages/__snapshots__/ToolMessage.test.tsx.snap +++ b/packages/cli/src/ui/components/messages/__snapshots__/ToolMessage.test.tsx.snap @@ -81,6 +81,16 @@ exports[` > renders DiffRenderer for diff results 1`] = ` │ 1 + new │" `; +exports[` > renders MCPProgressIndicator when executing with progress 1`] = ` +"╭──────────────────────────────────────────────────────────────────────────────╮ +│ ⊶ test-tool A tool for testing │ +│ │ +│ │ +│ ████████████████░░░░░░░░░░░░░░░░ 50% │ +│ Processing... │ +│ Test result │" +`; + exports[` > renders basic tool information 1`] = ` "╭──────────────────────────────────────────────────────────────────────────────╮ │ ✓ test-tool A tool for testing │ diff --git a/packages/cli/src/ui/components/messages/__snapshots__/ToolShared.test.tsx.snap b/packages/cli/src/ui/components/messages/__snapshots__/ToolShared.test.tsx.snap new file mode 100644 index 00000000000..2a0ed01b0ef --- /dev/null +++ b/packages/cli/src/ui/components/messages/__snapshots__/ToolShared.test.tsx.snap @@ -0,0 +1,32 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`MCPProgressIndicator > clamps progress exceeding total 1`] = ` +" +████████████████████ 150%" +`; + +exports[`MCPProgressIndicator > renders determinate progress bar at 50% 1`] = ` +" +██████████░░░░░░░░░░ 50%" +`; + +exports[`MCPProgressIndicator > renders fully complete progress bar 1`] = ` +" +████████████████████ 100%" +`; + +exports[`MCPProgressIndicator > renders indeterminate progress without total 1`] = ` +" +█████░░░░░░░░░░░░░░░ 5" +`; + +exports[`MCPProgressIndicator > renders progress message when provided 1`] = ` +" +██████████░░░░░░░░░░ 50% +Downloading..." +`; + +exports[`MCPProgressIndicator > scales bar width correctly 1`] = ` +" +████████████████████░░░░░░░░░░░░░░░░░░░░ 50%" +`; diff --git a/packages/cli/src/ui/hooks/toolMapping.ts b/packages/cli/src/ui/hooks/toolMapping.ts index d921651e514..91e3f8263fd 100644 --- a/packages/cli/src/ui/hooks/toolMapping.ts +++ b/packages/cli/src/ui/hooks/toolMapping.ts @@ -11,6 +11,7 @@ import { debugLogger, CoreToolCallStatus, } from '@google/gemini-cli-core'; +import type { Progress } from '@modelcontextprotocol/sdk/types.js'; import { type HistoryItemToolGroup, type IndividualToolCallDisplay, @@ -54,6 +55,7 @@ export function mapToDisplay( let outputFile: string | undefined = undefined; let ptyId: number | undefined = undefined; let correlationId: string | undefined = undefined; + let mcpProgress: Progress | undefined = undefined; switch (call.status) { case CoreToolCallStatus.Success: @@ -72,6 +74,7 @@ export function mapToDisplay( case CoreToolCallStatus.Executing: resultDisplay = call.liveOutput; ptyId = call.pid; + mcpProgress = call.mcpProgress; break; case CoreToolCallStatus.Scheduled: case CoreToolCallStatus.Validating: @@ -96,6 +99,7 @@ export function mapToDisplay( ptyId, correlationId, approvalMode: call.approvalMode, + mcpProgress, }; }); diff --git a/packages/cli/src/ui/hooks/useMCPProgress.test.ts b/packages/cli/src/ui/hooks/useMCPProgress.test.ts new file mode 100644 index 00000000000..4036bd00794 --- /dev/null +++ b/packages/cli/src/ui/hooks/useMCPProgress.test.ts @@ -0,0 +1,242 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { renderHook } from '../../test-utils/render.js'; +import { act } from 'react'; +import { useMCPProgress } from './useMCPProgress.js'; +import { coreEvents, CoreEvent } from '@google/gemini-cli-core'; + +describe('useMCPProgress', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + afterEach(() => { + // Clean up any lingering listeners + }); + + it('should initialize with empty state', () => { + const { result } = renderHook(() => useMCPProgress()); + expect(result.current.progressState).toEqual({}); + }); + + it('should update state when progress event is received', () => { + const { result } = renderHook(() => useMCPProgress()); + + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server', + toolName: 'tool', + progress: 50, + total: 100, + message: 'Processing...', + }); + }); + + expect(result.current.progressState).toEqual({ + 'call-1': { + progress: 50, + total: 100, + message: 'Processing...', + }, + }); + }); + + it('should track multiple concurrent tool calls', () => { + const { result } = renderHook(() => useMCPProgress()); + + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server1', + toolName: 'tool1', + progress: 25, + }); + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-2', + serverName: 'server2', + toolName: 'tool2', + progress: 75, + total: 100, + }); + }); + + expect(result.current.progressState).toEqual({ + 'call-1': { progress: 25, total: undefined, message: undefined }, + 'call-2': { progress: 75, total: 100, message: undefined }, + }); + }); + + it('should update existing progress for same callId', () => { + const { result } = renderHook(() => useMCPProgress()); + + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server', + toolName: 'tool', + progress: 25, + }); + }); + + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server', + toolName: 'tool', + progress: 75, + message: 'Almost done', + }); + }); + + expect(result.current.progressState['call-1']).toEqual({ + progress: 75, + total: undefined, + message: 'Almost done', + }); + }); + + it('should clear progress for specific callId', () => { + const { result } = renderHook(() => useMCPProgress()); + + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server', + toolName: 'tool', + progress: 50, + }); + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-2', + serverName: 'server', + toolName: 'tool', + progress: 75, + }); + }); + + act(() => { + result.current.clearProgress('call-1'); + }); + + expect(result.current.progressState).toEqual({ + 'call-2': { progress: 75, total: undefined, message: undefined }, + }); + }); + + it('should clear all progress', () => { + const { result } = renderHook(() => useMCPProgress()); + + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server', + toolName: 'tool', + progress: 50, + }); + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-2', + serverName: 'server', + toolName: 'tool', + progress: 75, + }); + }); + + act(() => { + result.current.clearAllProgress(); + }); + + expect(result.current.progressState).toEqual({}); + }); + + it('should reject late progress events for completed callIds', () => { + const { result } = renderHook(() => useMCPProgress()); + + // Emit progress for a call + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server', + toolName: 'tool', + progress: 50, + }); + }); + + // Clear progress (simulating tool reaching terminal state) + act(() => { + result.current.clearProgress('call-1'); + }); + + expect(result.current.progressState).toEqual({}); + + // Late progress event for same callId should be ignored + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server', + toolName: 'tool', + progress: 75, + }); + }); + + // Should still be empty — late event rejected + expect(result.current.progressState).toEqual({}); + }); + + it('should allow progress for previously-completed callId after clearAllProgress', () => { + const { result } = renderHook(() => useMCPProgress()); + + // Complete a call + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server', + toolName: 'tool', + progress: 100, + }); + result.current.clearProgress('call-1'); + }); + + // clearAllProgress resets completed tracking (new schedule) + act(() => { + result.current.clearAllProgress(); + }); + + // Same callId should now be accepted again (fresh schedule) + act(() => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'call-1', + serverName: 'server', + toolName: 'tool', + progress: 25, + }); + }); + + expect(result.current.progressState['call-1']).toEqual({ + progress: 25, + total: undefined, + message: undefined, + }); + }); + + it('should unsubscribe from events on unmount', () => { + const listenersBefore = coreEvents.listenerCount(CoreEvent.MCPToolProgress); + const { unmount } = renderHook(() => useMCPProgress()); + + // Hook should have added a listener + expect(coreEvents.listenerCount(CoreEvent.MCPToolProgress)).toBe( + listenersBefore + 1, + ); + + unmount(); + + // Listener should be removed after unmount + expect(coreEvents.listenerCount(CoreEvent.MCPToolProgress)).toBe( + listenersBefore, + ); + }); +}); diff --git a/packages/cli/src/ui/hooks/useMCPProgress.ts b/packages/cli/src/ui/hooks/useMCPProgress.ts new file mode 100644 index 00000000000..17466749415 --- /dev/null +++ b/packages/cli/src/ui/hooks/useMCPProgress.ts @@ -0,0 +1,68 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useState, useEffect, useCallback, useRef } from 'react'; +import { + coreEvents, + CoreEvent, + type MCPToolProgressPayload, +} from '@google/gemini-cli-core'; +import type { Progress } from '@modelcontextprotocol/sdk/types.js'; + +export interface MCPProgressState { + [callId: string]: Progress; +} + +/** + * Hook to track MCP tool progress updates. + * Subscribes to CoreEvent.MCPToolProgress and maintains state per callId. + */ +export const useMCPProgress = () => { + const [progressState, setProgressState] = useState({}); + // Track completed callIds to reject late-arriving progress events. + // Without this, a progress event arriving after clearProgress() would + // reintroduce a stale entry into progressState. + const completedCallIds = useRef>(new Set()); + + useEffect(() => { + const handleProgress = (payload: MCPToolProgressPayload) => { + // Reject progress for calls that have already reached a terminal state + if (completedCallIds.current.has(payload.callId)) { + return; + } + setProgressState((prev) => ({ + ...prev, + [payload.callId]: { + progress: payload.progress, + total: payload.total, + message: payload.message, + }, + })); + }; + + coreEvents.on(CoreEvent.MCPToolProgress, handleProgress); + return () => { + coreEvents.off(CoreEvent.MCPToolProgress, handleProgress); + }; + }, []); + + const clearProgress = useCallback((callId: string) => { + completedCallIds.current.add(callId); + setProgressState((prev) => { + if (!(callId in prev)) return prev; // No state to clear — avoid re-render + const { [callId]: _, ...rest } = prev; + return rest; + }); + }, []); + + const clearAllProgress = useCallback(() => { + // Reset completed tracking when starting fresh (new schedule) + completedCallIds.current.clear(); + setProgressState({}); + }, []); + + return { progressState, clearProgress, clearAllProgress }; +}; diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts index 89bee143420..d8d675f15a9 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.ts @@ -16,6 +16,8 @@ import { type ToolCallsUpdateMessage, } from '@google/gemini-cli-core'; import { useCallback, useState, useMemo, useEffect, useRef } from 'react'; +import type { Progress } from '@modelcontextprotocol/sdk/types.js'; +import { useMCPProgress } from './useMCPProgress.js'; // Re-exporting types compatible with hook expectations export type ScheduleFn = ( @@ -32,6 +34,7 @@ export type CancelAllFn = (signal: AbortSignal) => void; */ export type TrackedToolCall = ToolCall & { responseSubmittedToGemini?: boolean; + mcpProgress?: Progress; }; // Narrowed types for specific statuses (used by useGeminiStream) @@ -81,6 +84,8 @@ export function useToolScheduler( >({}); const [lastToolOutputTime, setLastToolOutputTime] = useState(0); + const { progressState, clearProgress, clearAllProgress } = useMCPProgress(); + const messageBus = useMemo(() => config.getMessageBus(), [config]); const onCompleteRef = useRef(onComplete); @@ -117,6 +122,12 @@ export function useToolScheduler( setLastToolOutputTime(Date.now()); } + event.toolCalls.forEach((tc) => { + if (['success', 'cancelled', 'error'].includes(tc.status)) { + clearProgress(tc.request.callId); + } + }); + setToolCallsMap((prev) => { const adapted = internalAdaptToolCalls( event.toolCalls, @@ -134,12 +145,13 @@ export function useToolScheduler( return () => { messageBus.unsubscribe(MessageBusType.TOOL_CALLS_UPDATE, handler); }; - }, [messageBus, internalAdaptToolCalls]); + }, [messageBus, internalAdaptToolCalls, clearProgress]); const schedule: ScheduleFn = useCallback( async (request, signal) => { // Clear state for new run setToolCallsMap({}); + clearAllProgress(); // 1. Await Core Scheduler directly const results = await scheduler.schedule(request, signal); @@ -151,7 +163,7 @@ export function useToolScheduler( return results; }, - [scheduler], + [scheduler, clearAllProgress], ); const cancelAll: CancelAllFn = useCallback( @@ -184,6 +196,17 @@ export function useToolScheduler( [toolCallsMap], ); + const toolCallsWithProgress = useMemo( + () => + toolCalls.map((tc): TrackedToolCall => { + if (tc.status === 'executing' && progressState[tc.request.callId]) { + return { ...tc, mcpProgress: progressState[tc.request.callId] }; + } + return tc; + }), + [toolCalls, progressState], + ); + // Provide a setter that maintains compatibility with legacy []. const setToolCallsForDisplay = useCallback( (action: React.SetStateAction) => { @@ -214,7 +237,7 @@ export function useToolScheduler( ); return [ - toolCalls, + toolCallsWithProgress, schedule, markToolsAsSubmitted, setToolCallsForDisplay, diff --git a/packages/cli/src/ui/types.ts b/packages/cli/src/ui/types.ts index 8481cca71f2..ea14c6c0470 100644 --- a/packages/cli/src/ui/types.ts +++ b/packages/cli/src/ui/types.ts @@ -18,6 +18,7 @@ import { CoreToolCallStatus, checkExhaustive, } from '@google/gemini-cli-core'; +import type { Progress } from '@modelcontextprotocol/sdk/types.js'; import type { PartListUnion } from '@google/genai'; import { type ReactNode } from 'react'; @@ -108,6 +109,7 @@ export interface IndividualToolCallDisplay { outputFile?: string; correlationId?: string; approvalMode?: ApprovalMode; + mcpProgress?: Progress; } export interface CompressionProps { diff --git a/packages/core/src/scheduler/tool-executor.test.ts b/packages/core/src/scheduler/tool-executor.test.ts index 53b244031db..51b81e38b0a 100644 --- a/packages/core/src/scheduler/tool-executor.test.ts +++ b/packages/core/src/scheduler/tool-executor.test.ts @@ -293,4 +293,74 @@ describe('ToolExecutor', () => { }), ); }); + + describe('setCallId for MCP progress', () => { + it('should call setCallId on invocation before execution when method exists', async () => { + const mockSetCallId = vi.fn(); + const mockTool = new MockTool({ name: 'mcp-tool' }); + const invocation = mockTool.build({}); + invocation.setCallId = mockSetCallId; + + // Mock executeToolWithHooks to return success + vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockResolvedValue({ + llmContent: 'done', + returnDisplay: 'done', + } as ToolResult); + + const scheduledCall: ScheduledToolCall = { + status: CoreToolCallStatus.Scheduled, + request: { + callId: 'test-call-123', + name: 'mcp-tool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-setcallid-1', + }, + tool: mockTool, + invocation, + startTime: Date.now(), + }; + + await executor.execute({ + call: scheduledCall, + signal: new AbortController().signal, + onUpdateToolCall: vi.fn(), + }); + + expect(mockSetCallId).toHaveBeenCalledWith('test-call-123'); + }); + + it('should not fail when invocation lacks setCallId method', async () => { + const mockTool = new MockTool({ name: 'shell-tool' }); + const invocation = mockTool.build({}); + + vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockResolvedValue({ + llmContent: 'done', + returnDisplay: 'done', + } as ToolResult); + + const scheduledCall: ScheduledToolCall = { + status: CoreToolCallStatus.Scheduled, + request: { + callId: 'test-call-456', + name: 'shell-tool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-setcallid-2', + }, + tool: mockTool, + invocation, + startTime: Date.now(), + }; + + // Should not throw + await expect( + executor.execute({ + call: scheduledCall, + signal: new AbortController().signal, + onUpdateToolCall: vi.fn(), + }), + ).resolves.toBeDefined(); + }); + }); }); diff --git a/packages/core/src/scheduler/tool-executor.ts b/packages/core/src/scheduler/tool-executor.ts index 116598a2b95..b6e362ad00f 100644 --- a/packages/core/src/scheduler/tool-executor.ts +++ b/packages/core/src/scheduler/tool-executor.ts @@ -58,6 +58,8 @@ export class ToolExecutor { } const { tool, invocation } = call; + invocation?.setCallId?.(callId); + // Setup live output handling const liveOutputCallback = tool.canUpdateOutput && outputUpdateHandler diff --git a/packages/core/src/scheduler/types.ts b/packages/core/src/scheduler/types.ts index b09c42fe514..9eb4d7aac6a 100644 --- a/packages/core/src/scheduler/types.ts +++ b/packages/core/src/scheduler/types.ts @@ -16,6 +16,7 @@ import type { AnsiOutput } from '../utils/terminalSerializer.js'; import type { ToolErrorType } from '../tools/tool-error.js'; import type { SerializableConfirmationDetails } from '../confirmation-bus/types.js'; import { type ApprovalMode } from '../policy/types.js'; +import type { Progress } from '@modelcontextprotocol/sdk/types.js'; export const ROOT_SCHEDULER_ID = 'root'; @@ -114,6 +115,7 @@ export type ExecutingToolCall = { pid?: number; schedulerId?: string; approvalMode?: ApprovalMode; + mcpProgress?: Progress; }; export type CancelledToolCall = { diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 3f289f17322..5b41631e2fe 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -25,11 +25,15 @@ import { WorkspaceContext } from '../utils/workspaceContext.js'; import { connectToMcpServer, createTransport, + discoverTools, hasNetworkTransport, isEnabled, McpClient, populateMcpServerCommand, } from './mcp-client.js'; +import type { DiscoveredMCPToolInvocation } from './mcp-tool.js'; +import { MCPServerConfig } from '../config/config.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { ToolRegistry } from './tool-registry.js'; import type { ResourceRegistry } from '../resources/resource-registry.js'; import * as fs from 'node:fs'; @@ -57,6 +61,7 @@ vi.mock('../utils/events.js', () => ({ coreEvents: { emitFeedback: vi.fn(), emitConsoleLog: vi.fn(), + emitMCPToolProgress: vi.fn(), }, })); @@ -2265,3 +2270,77 @@ describe('connectToMcpServer - OAuth with transport fallback', () => { expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); }); }); + +describe('McpCallableTool SDK options forwarding (integration)', () => { + let mockSdkCallTool: ReturnType; + let tools: Awaited>; + + beforeEach(async () => { + mockSdkCallTool = vi.fn().mockResolvedValue({ + content: [{ type: 'text', text: 'done' }], + }); + + const mockedSdkClient = { + getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { type: 'object', properties: {} }, + }, + ], + }), + callTool: mockSdkCallTool, + }; + + const mockConfig = { + getPolicyEngine: vi.fn().mockReturnValue({ addRule: vi.fn() }), + sanitizationConfig: EMPTY_CONFIG, + } as unknown as Config; + + const mockMessageBus = { + publish: vi.fn(), + on: vi.fn(), + off: vi.fn(), + emit: vi.fn(), + } as unknown as MessageBus; + + tools = await discoverTools( + 'test-server', + new MCPServerConfig('test-cmd'), + mockedSdkClient as unknown as ClientLib.Client, + mockConfig, + mockMessageBus, + ); + + expect(tools).toHaveLength(1); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should forward onprogress to SDK client.callTool when invocation has callId', async () => { + const invocation = tools[0].build({}) as DiscoveredMCPToolInvocation; + invocation.setCallId('call-123'); + await invocation.execute(new AbortController().signal); + + expect(mockSdkCallTool).toHaveBeenCalledWith( + { name: 'test-tool', arguments: {} }, + undefined, + expect.objectContaining({ + onprogress: expect.any(Function), + resetTimeoutOnProgress: true, + }), + ); + }); + + it('should not pass onprogress when callId is not set', async () => { + const invocation = tools[0].build({}); + await invocation.execute(new AbortController().signal); + + const callOptions = mockSdkCallTool.mock.calls[0][2]; + expect(callOptions.onprogress).toBeUndefined(); + }); +}); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 7902d8953a4..e9bbff3897a 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -22,6 +22,7 @@ import type { Prompt, ReadResourceResult, Resource, + Progress, } from '@modelcontextprotocol/sdk/types.js'; import { ListResourcesResultSchema, @@ -1073,7 +1074,18 @@ export async function discoverTools( } } -class McpCallableTool implements CallableTool { +/** + * Extended CallableTool interface that supports progress reporting. + * Used by MCP tools that can emit progress updates. + */ +export interface CallableToolWithProgress extends CallableTool { + callTool( + functionCalls: FunctionCall[], + progressCallback?: (progress: Progress) => void, + ): Promise; +} + +class McpCallableTool implements CallableToolWithProgress { constructor( private readonly client: Client, private readonly toolDef: McpTool, @@ -1092,7 +1104,10 @@ class McpCallableTool implements CallableTool { }; } - async callTool(functionCalls: FunctionCall[]): Promise { + async callTool( + functionCalls: FunctionCall[], + progressCallback?: (progress: Progress) => void, + ): Promise { // We only expect one function call at a time for MCP tools in this context if (functionCalls.length !== 1) { throw new Error('McpCallableTool only supports single function call'); @@ -1107,7 +1122,11 @@ class McpCallableTool implements CallableTool { arguments: call.args as Record, }, undefined, - { timeout: this.timeout }, + { + timeout: this.timeout, + onprogress: progressCallback, + resetTimeoutOnProgress: true, + }, ); return [ diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts index 4cdad898274..fc1cfcc8906 100644 --- a/packages/core/src/tools/mcp-tool.test.ts +++ b/packages/core/src/tools/mcp-tool.test.ts @@ -8,15 +8,21 @@ import type { Mocked } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { safeJsonStringify } from '../utils/safeJsonStringify.js'; -import { DiscoveredMCPTool, generateValidName } from './mcp-tool.js'; // Added getStringifiedResultForDisplay +import { + DiscoveredMCPTool, + type DiscoveredMCPToolInvocation, + generateValidName, +} from './mcp-tool.js'; import type { ToolResult } from './tools.js'; import { ToolConfirmationOutcome } from './tools.js'; // Added ToolConfirmationOutcome -import type { CallableTool, Part } from '@google/genai'; +import type { CallableTool, FunctionCall, Part } from '@google/genai'; +import type { Progress } from '@modelcontextprotocol/sdk/types.js'; import { ToolErrorType } from './tool-error.js'; import { createMockMessageBus, getMockMessageBusInstance, } from '../test-utils/mock-message-bus.js'; +import { coreEvents, CoreEvent } from '../utils/events.js'; // Mock @google/genai mcpToTool and CallableTool // We only need to mock the parts of CallableTool that DiscoveredMCPTool uses. @@ -932,4 +938,121 @@ describe('DiscoveredMCPTool', () => { expect(description).toBe('{"param":"testValue","param2":"anotherOne"}'); }); }); + + describe('DiscoveredMCPToolInvocation progress', () => { + let progressMockCallTool: ReturnType; + let mockBus: ReturnType; + + beforeEach(() => { + progressMockCallTool = vi.fn(); + mockBus = createMockMessageBus(); + }); + + function buildInvocation(callId?: string) { + const mockCallableToolInst = { + tool: vi.fn(), + callTool: progressMockCallTool, + } as unknown as Mocked; + + const testTool = new DiscoveredMCPTool( + mockCallableToolInst, + 'test-server', + 'test-tool', + 'A test tool', + { type: 'object', properties: {} }, + mockBus, + false, + false, + ); + + const invocation = testTool.build({ + param: 'value', + }) as DiscoveredMCPToolInvocation; + if (callId) { + invocation.setCallId(callId); + } + return invocation; + } + + it('should pass progressCallback as second arg to callTool when callId is set', async () => { + progressMockCallTool.mockResolvedValue( + createSdkResponse('test-tool', { + content: [{ type: 'text', text: 'done' }], + }), + ); + + await buildInvocation('test-call-123').execute( + new AbortController().signal, + ); + + expect(progressMockCallTool).toHaveBeenCalledWith( + expect.any(Array), + expect.any(Function), // progressCallback + ); + }); + + it('should not pass progressCallback when callId is not set', async () => { + progressMockCallTool.mockResolvedValue( + createSdkResponse('test-tool', { + content: [{ type: 'text', text: 'done' }], + }), + ); + + await buildInvocation().execute(new AbortController().signal); + + expect(progressMockCallTool).toHaveBeenCalledWith( + expect.any(Array), // functionCalls only — no second arg + ); + }); + + it('should emit progress events via coreEvents when callId is set', async () => { + const handler = vi.fn(); + coreEvents.on(CoreEvent.MCPToolProgress, handler); + + progressMockCallTool.mockImplementation( + async (_calls: FunctionCall[], progressCb?: (p: Progress) => void) => { + if (progressCb) { + progressCb({ progress: 25, total: 100, message: 'Starting...' }); + progressCb({ progress: 50, total: 100, message: 'Halfway...' }); + progressCb({ progress: 100, total: 100, message: 'Done!' }); + } + return [ + { + functionResponse: { name: 'test', response: { content: [] } }, + }, + ]; + }, + ); + + await buildInvocation('test-call-123').execute( + new AbortController().signal, + ); + + expect(handler).toHaveBeenCalledTimes(3); + expect(handler).toHaveBeenNthCalledWith(1, { + callId: 'test-call-123', + serverName: 'test-server', + toolName: 'test-tool', + progress: 25, + total: 100, + message: 'Starting...', + }); + + coreEvents.off(CoreEvent.MCPToolProgress, handler); + }); + + it('should not emit progress events when callId is not set', async () => { + const handler = vi.fn(); + coreEvents.on(CoreEvent.MCPToolProgress, handler); + + progressMockCallTool.mockResolvedValue([ + { functionResponse: { name: 'test', response: { content: [] } } }, + ]); + + await buildInvocation().execute(new AbortController().signal); + + expect(handler).not.toHaveBeenCalled(); + coreEvents.off(CoreEvent.MCPToolProgress, handler); + }); + }); }); diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index c4d7a320384..6876cc528ec 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -19,9 +19,12 @@ import { type PolicyUpdateOptions, } from './tools.js'; import type { CallableTool, FunctionCall, Part } from '@google/genai'; +import type { Progress } from '@modelcontextprotocol/sdk/types.js'; import { ToolErrorType } from './tool-error.js'; import type { Config } from '../config/config.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { coreEvents } from '../utils/events.js'; +import type { CallableToolWithProgress } from './mcp-client.js'; /** * The separator used to qualify MCP tool names with their server prefix. @@ -70,6 +73,15 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation< ToolResult > { private static readonly allowlist: Set = new Set(); + private _callId?: string; + + setCallId(callId: string): void { + this._callId = callId; + } + + get callId(): string | undefined { + return this._callId; + } constructor( private readonly mcpTool: CallableTool, @@ -174,6 +186,19 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation< }, ]; + const progressCallback = this._callId + ? (progress: Progress) => { + coreEvents.emitMCPToolProgress({ + callId: this._callId!, + serverName: this.serverName, + toolName: this.serverToolName, + progress: progress.progress, + total: progress.total, + message: progress.message, + }); + } + : undefined; + // Race MCP tool call with abort signal to respect cancellation const rawResponseParts = await new Promise((resolve, reject) => { if (signal.aborted) { @@ -193,8 +218,16 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation< }; signal.addEventListener('abort', onAbort, { once: true }); - this.mcpTool - .callTool(functionCalls) + // Conditionally pass progressCallback to avoid passing undefined as + // a second arg (which would change the call signature for existing tests) + const callPromise = progressCallback + ? (this.mcpTool as CallableToolWithProgress).callTool( + functionCalls, + progressCallback, + ) + : this.mcpTool.callTool(functionCalls); + + callPromise .then((res) => { cleanup(); resolve(res); diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 3d90e80699f..8c53bed846c 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -67,6 +67,11 @@ export interface ToolInvocation< updateOutput?: (output: string | AnsiOutput) => void, shellExecutionConfig?: ShellExecutionConfig, ): Promise; + + /** + * Sets a correlation ID for progress tracking. + */ + setCallId?(callId: string): void; } /** diff --git a/packages/core/src/utils/events.ts b/packages/core/src/utils/events.ts index 014c2eec7a8..1c58f341bde 100644 --- a/packages/core/src/utils/events.ts +++ b/packages/core/src/utils/events.ts @@ -13,6 +13,7 @@ import type { TokenStorageInitializationEvent, KeychainAvailabilityEvent, } from '../telemetry/types.js'; +import { debugLogger } from './debugLogger.js'; /** * Defines the severity level for user-facing feedback. @@ -151,6 +152,24 @@ export interface QuotaChangedPayload { resetTime?: string; } +/** + * Payload for the 'mcp-tool-progress' event. + */ +export interface MCPToolProgressPayload { + /** The unique identifier for this tool call */ + callId: string; + /** The name of the MCP server */ + serverName: string; + /** The name of the tool being executed */ + toolName: string; + /** Current progress value (must be non-negative) */ + progress: number; + /** Total value for percentage calculation (optional) */ + total?: number; + /** Human-readable progress message (optional) */ + message?: string; +} + export enum CoreEvent { UserFeedback = 'user-feedback', ModelChanged = 'model-changed', @@ -174,6 +193,7 @@ export enum CoreEvent { QuotaChanged = 'quota-changed', TelemetryKeychainAvailability = 'telemetry-keychain-availability', TelemetryTokenStorageType = 'telemetry-token-storage-type', + MCPToolProgress = 'mcp-tool-progress', } /** @@ -206,6 +226,7 @@ export interface CoreEvents extends ExtensionEvents { [CoreEvent.SlashCommandConflicts]: [SlashCommandConflictsPayload]; [CoreEvent.TelemetryKeychainAvailability]: [KeychainAvailabilityEvent]; [CoreEvent.TelemetryTokenStorageType]: [TokenStorageInitializationEvent]; + [CoreEvent.MCPToolProgress]: [MCPToolProgressPayload]; } type EventBacklogItem = { @@ -360,6 +381,21 @@ export class CoreEventEmitter extends EventEmitter { this.emit(CoreEvent.QuotaChanged, payload); } + /** + * Notifies subscribers of MCP tool progress updates. + * Uses direct emit (not _emitOrQueue) because progress events are: + * - High-frequency and transient (not worth queuing) + * - Disposable if no listener exists (no backlog needed) + * - A risk to the shared backlog if queued at high volume + */ + emitMCPToolProgress(payload: MCPToolProgressPayload): void { + if (!Number.isFinite(payload.progress) || payload.progress < 0) { + debugLogger.log(`Invalid progress value: ${payload.progress}`); + return; + } + this.emit(CoreEvent.MCPToolProgress, payload); + } + /** * Flushes buffered messages. Call this immediately after primary UI listener * subscribes. From 79a33f6e3b5afb6ac264cbd320b3d7aa4ea04817 Mon Sep 17 00:00:00 2001 From: Jasmeet Bhatia Date: Tue, 17 Feb 2026 22:30:43 +0000 Subject: [PATCH 2/2] refactor(mcp): move progress state management from CLI to Core Move MCP progress tracking into StateManager/Scheduler so all clients get progress via TOOL_CALLS_UPDATE. Delete useMCPProgress hook, fix CallableToolWithProgress cast, cap percentage at 100%. --- .../cli/src/ui/commands/mcpCommand.test.ts | 5 +- .../src/ui/components/messages/ToolShared.tsx | 4 +- .../__snapshots__/ToolMessage.test.tsx.snap | 2 +- .../__snapshots__/ToolShared.test.tsx.snap | 2 +- .../cli/src/ui/hooks/useMCPProgress.test.ts | 242 ------------------ packages/cli/src/ui/hooks/useMCPProgress.ts | 68 ----- packages/cli/src/ui/hooks/useToolScheduler.ts | 29 +-- .../core/src/agents/local-executor.test.ts | 4 +- .../core/src/core/coreToolScheduler.test.ts | 4 +- packages/core/src/core/prompts.test.ts | 6 +- packages/core/src/scheduler/scheduler.test.ts | 126 +++++++++ packages/core/src/scheduler/scheduler.ts | 109 ++++---- .../core/src/scheduler/state-manager.test.ts | 119 +++++++++ packages/core/src/scheduler/state-manager.ts | 39 +++ packages/core/src/telemetry/loggers.test.ts | 4 +- packages/core/src/tools/mcp-tool.test.ts | 11 +- packages/core/src/tools/mcp-tool.ts | 11 +- packages/core/src/tools/tool-registry.test.ts | 8 +- 18 files changed, 383 insertions(+), 410 deletions(-) delete mode 100644 packages/cli/src/ui/hooks/useMCPProgress.test.ts delete mode 100644 packages/cli/src/ui/hooks/useMCPProgress.ts diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index ecce5c9cd5d..afaaefa0ac5 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -14,9 +14,8 @@ import { getMCPDiscoveryState, DiscoveredMCPTool, type MessageBus, + type CallableToolWithProgress, } from '@google/gemini-cli-core'; - -import type { CallableTool } from '@google/genai'; import { MessageType } from '../types.js'; vi.mock('@google/gemini-cli-core', async (importOriginal) => { @@ -53,7 +52,7 @@ const createMockMCPTool = ( { callTool: vi.fn(), tool: vi.fn(), - } as unknown as CallableTool, + } as unknown as CallableToolWithProgress, serverName, name, description || 'Mock tool description', diff --git a/packages/cli/src/ui/components/messages/ToolShared.tsx b/packages/cli/src/ui/components/messages/ToolShared.tsx index 4cf9de92701..cf9549a9e6b 100644 --- a/packages/cli/src/ui/components/messages/ToolShared.tsx +++ b/packages/cli/src/ui/components/messages/ToolShared.tsx @@ -252,7 +252,9 @@ export const MCPProgressIndicator: React.FC = ({ barWidth, }) => { const percentage = - total && total > 0 ? Math.round((progress / total) * 100) : null; + total && total > 0 + ? Math.min(100, Math.round((progress / total) * 100)) + : null; let rawFilled: number; if (total && total > 0) { diff --git a/packages/cli/src/ui/components/messages/__snapshots__/ToolMessage.test.tsx.snap b/packages/cli/src/ui/components/messages/__snapshots__/ToolMessage.test.tsx.snap index a24d7341ba4..73284200c5c 100644 --- a/packages/cli/src/ui/components/messages/__snapshots__/ToolMessage.test.tsx.snap +++ b/packages/cli/src/ui/components/messages/__snapshots__/ToolMessage.test.tsx.snap @@ -83,7 +83,7 @@ exports[` > renders DiffRenderer for diff results 1`] = ` exports[` > renders MCPProgressIndicator when executing with progress 1`] = ` "╭──────────────────────────────────────────────────────────────────────────────╮ -│ ⊶ test-tool A tool for testing │ +│ MockRespondingSpinnertest-tool A tool for testing │ │ │ │ │ │ ████████████████░░░░░░░░░░░░░░░░ 50% │ diff --git a/packages/cli/src/ui/components/messages/__snapshots__/ToolShared.test.tsx.snap b/packages/cli/src/ui/components/messages/__snapshots__/ToolShared.test.tsx.snap index 2a0ed01b0ef..48acf16e2db 100644 --- a/packages/cli/src/ui/components/messages/__snapshots__/ToolShared.test.tsx.snap +++ b/packages/cli/src/ui/components/messages/__snapshots__/ToolShared.test.tsx.snap @@ -2,7 +2,7 @@ exports[`MCPProgressIndicator > clamps progress exceeding total 1`] = ` " -████████████████████ 150%" +████████████████████ 100%" `; exports[`MCPProgressIndicator > renders determinate progress bar at 50% 1`] = ` diff --git a/packages/cli/src/ui/hooks/useMCPProgress.test.ts b/packages/cli/src/ui/hooks/useMCPProgress.test.ts deleted file mode 100644 index 4036bd00794..00000000000 --- a/packages/cli/src/ui/hooks/useMCPProgress.test.ts +++ /dev/null @@ -1,242 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { renderHook } from '../../test-utils/render.js'; -import { act } from 'react'; -import { useMCPProgress } from './useMCPProgress.js'; -import { coreEvents, CoreEvent } from '@google/gemini-cli-core'; - -describe('useMCPProgress', () => { - beforeEach(() => { - vi.clearAllMocks(); - }); - - afterEach(() => { - // Clean up any lingering listeners - }); - - it('should initialize with empty state', () => { - const { result } = renderHook(() => useMCPProgress()); - expect(result.current.progressState).toEqual({}); - }); - - it('should update state when progress event is received', () => { - const { result } = renderHook(() => useMCPProgress()); - - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server', - toolName: 'tool', - progress: 50, - total: 100, - message: 'Processing...', - }); - }); - - expect(result.current.progressState).toEqual({ - 'call-1': { - progress: 50, - total: 100, - message: 'Processing...', - }, - }); - }); - - it('should track multiple concurrent tool calls', () => { - const { result } = renderHook(() => useMCPProgress()); - - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server1', - toolName: 'tool1', - progress: 25, - }); - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-2', - serverName: 'server2', - toolName: 'tool2', - progress: 75, - total: 100, - }); - }); - - expect(result.current.progressState).toEqual({ - 'call-1': { progress: 25, total: undefined, message: undefined }, - 'call-2': { progress: 75, total: 100, message: undefined }, - }); - }); - - it('should update existing progress for same callId', () => { - const { result } = renderHook(() => useMCPProgress()); - - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server', - toolName: 'tool', - progress: 25, - }); - }); - - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server', - toolName: 'tool', - progress: 75, - message: 'Almost done', - }); - }); - - expect(result.current.progressState['call-1']).toEqual({ - progress: 75, - total: undefined, - message: 'Almost done', - }); - }); - - it('should clear progress for specific callId', () => { - const { result } = renderHook(() => useMCPProgress()); - - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server', - toolName: 'tool', - progress: 50, - }); - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-2', - serverName: 'server', - toolName: 'tool', - progress: 75, - }); - }); - - act(() => { - result.current.clearProgress('call-1'); - }); - - expect(result.current.progressState).toEqual({ - 'call-2': { progress: 75, total: undefined, message: undefined }, - }); - }); - - it('should clear all progress', () => { - const { result } = renderHook(() => useMCPProgress()); - - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server', - toolName: 'tool', - progress: 50, - }); - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-2', - serverName: 'server', - toolName: 'tool', - progress: 75, - }); - }); - - act(() => { - result.current.clearAllProgress(); - }); - - expect(result.current.progressState).toEqual({}); - }); - - it('should reject late progress events for completed callIds', () => { - const { result } = renderHook(() => useMCPProgress()); - - // Emit progress for a call - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server', - toolName: 'tool', - progress: 50, - }); - }); - - // Clear progress (simulating tool reaching terminal state) - act(() => { - result.current.clearProgress('call-1'); - }); - - expect(result.current.progressState).toEqual({}); - - // Late progress event for same callId should be ignored - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server', - toolName: 'tool', - progress: 75, - }); - }); - - // Should still be empty — late event rejected - expect(result.current.progressState).toEqual({}); - }); - - it('should allow progress for previously-completed callId after clearAllProgress', () => { - const { result } = renderHook(() => useMCPProgress()); - - // Complete a call - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server', - toolName: 'tool', - progress: 100, - }); - result.current.clearProgress('call-1'); - }); - - // clearAllProgress resets completed tracking (new schedule) - act(() => { - result.current.clearAllProgress(); - }); - - // Same callId should now be accepted again (fresh schedule) - act(() => { - coreEvents.emit(CoreEvent.MCPToolProgress, { - callId: 'call-1', - serverName: 'server', - toolName: 'tool', - progress: 25, - }); - }); - - expect(result.current.progressState['call-1']).toEqual({ - progress: 25, - total: undefined, - message: undefined, - }); - }); - - it('should unsubscribe from events on unmount', () => { - const listenersBefore = coreEvents.listenerCount(CoreEvent.MCPToolProgress); - const { unmount } = renderHook(() => useMCPProgress()); - - // Hook should have added a listener - expect(coreEvents.listenerCount(CoreEvent.MCPToolProgress)).toBe( - listenersBefore + 1, - ); - - unmount(); - - // Listener should be removed after unmount - expect(coreEvents.listenerCount(CoreEvent.MCPToolProgress)).toBe( - listenersBefore, - ); - }); -}); diff --git a/packages/cli/src/ui/hooks/useMCPProgress.ts b/packages/cli/src/ui/hooks/useMCPProgress.ts deleted file mode 100644 index 17466749415..00000000000 --- a/packages/cli/src/ui/hooks/useMCPProgress.ts +++ /dev/null @@ -1,68 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { useState, useEffect, useCallback, useRef } from 'react'; -import { - coreEvents, - CoreEvent, - type MCPToolProgressPayload, -} from '@google/gemini-cli-core'; -import type { Progress } from '@modelcontextprotocol/sdk/types.js'; - -export interface MCPProgressState { - [callId: string]: Progress; -} - -/** - * Hook to track MCP tool progress updates. - * Subscribes to CoreEvent.MCPToolProgress and maintains state per callId. - */ -export const useMCPProgress = () => { - const [progressState, setProgressState] = useState({}); - // Track completed callIds to reject late-arriving progress events. - // Without this, a progress event arriving after clearProgress() would - // reintroduce a stale entry into progressState. - const completedCallIds = useRef>(new Set()); - - useEffect(() => { - const handleProgress = (payload: MCPToolProgressPayload) => { - // Reject progress for calls that have already reached a terminal state - if (completedCallIds.current.has(payload.callId)) { - return; - } - setProgressState((prev) => ({ - ...prev, - [payload.callId]: { - progress: payload.progress, - total: payload.total, - message: payload.message, - }, - })); - }; - - coreEvents.on(CoreEvent.MCPToolProgress, handleProgress); - return () => { - coreEvents.off(CoreEvent.MCPToolProgress, handleProgress); - }; - }, []); - - const clearProgress = useCallback((callId: string) => { - completedCallIds.current.add(callId); - setProgressState((prev) => { - if (!(callId in prev)) return prev; // No state to clear — avoid re-render - const { [callId]: _, ...rest } = prev; - return rest; - }); - }, []); - - const clearAllProgress = useCallback(() => { - // Reset completed tracking when starting fresh (new schedule) - completedCallIds.current.clear(); - setProgressState({}); - }, []); - - return { progressState, clearProgress, clearAllProgress }; -}; diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts index d8d675f15a9..89bee143420 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.ts @@ -16,8 +16,6 @@ import { type ToolCallsUpdateMessage, } from '@google/gemini-cli-core'; import { useCallback, useState, useMemo, useEffect, useRef } from 'react'; -import type { Progress } from '@modelcontextprotocol/sdk/types.js'; -import { useMCPProgress } from './useMCPProgress.js'; // Re-exporting types compatible with hook expectations export type ScheduleFn = ( @@ -34,7 +32,6 @@ export type CancelAllFn = (signal: AbortSignal) => void; */ export type TrackedToolCall = ToolCall & { responseSubmittedToGemini?: boolean; - mcpProgress?: Progress; }; // Narrowed types for specific statuses (used by useGeminiStream) @@ -84,8 +81,6 @@ export function useToolScheduler( >({}); const [lastToolOutputTime, setLastToolOutputTime] = useState(0); - const { progressState, clearProgress, clearAllProgress } = useMCPProgress(); - const messageBus = useMemo(() => config.getMessageBus(), [config]); const onCompleteRef = useRef(onComplete); @@ -122,12 +117,6 @@ export function useToolScheduler( setLastToolOutputTime(Date.now()); } - event.toolCalls.forEach((tc) => { - if (['success', 'cancelled', 'error'].includes(tc.status)) { - clearProgress(tc.request.callId); - } - }); - setToolCallsMap((prev) => { const adapted = internalAdaptToolCalls( event.toolCalls, @@ -145,13 +134,12 @@ export function useToolScheduler( return () => { messageBus.unsubscribe(MessageBusType.TOOL_CALLS_UPDATE, handler); }; - }, [messageBus, internalAdaptToolCalls, clearProgress]); + }, [messageBus, internalAdaptToolCalls]); const schedule: ScheduleFn = useCallback( async (request, signal) => { // Clear state for new run setToolCallsMap({}); - clearAllProgress(); // 1. Await Core Scheduler directly const results = await scheduler.schedule(request, signal); @@ -163,7 +151,7 @@ export function useToolScheduler( return results; }, - [scheduler, clearAllProgress], + [scheduler], ); const cancelAll: CancelAllFn = useCallback( @@ -196,17 +184,6 @@ export function useToolScheduler( [toolCallsMap], ); - const toolCallsWithProgress = useMemo( - () => - toolCalls.map((tc): TrackedToolCall => { - if (tc.status === 'executing' && progressState[tc.request.callId]) { - return { ...tc, mcpProgress: progressState[tc.request.callId] }; - } - return tc; - }), - [toolCalls, progressState], - ); - // Provide a setter that maintains compatibility with legacy []. const setToolCallsForDisplay = useCallback( (action: React.SetStateAction) => { @@ -237,7 +214,7 @@ export function useToolScheduler( ); return [ - toolCallsWithProgress, + toolCalls, schedule, markToolsAsSubmitted, setToolCallsForDisplay, diff --git a/packages/core/src/agents/local-executor.test.ts b/packages/core/src/agents/local-executor.test.ts index d2634ecc520..aae0cb8fdff 100644 --- a/packages/core/src/agents/local-executor.test.ts +++ b/packages/core/src/agents/local-executor.test.ts @@ -21,6 +21,7 @@ import { DiscoveredMCPTool, MCP_QUALIFIED_NAME_SEPARATOR, } from '../tools/mcp-tool.js'; +import type { CallableToolWithProgress } from '../tools/mcp-client.js'; import { LSTool } from '../tools/ls.js'; import { LS_TOOL_NAME, READ_FILE_TOOL_NAME } from '../tools/tool-names.js'; import { @@ -35,7 +36,6 @@ import { type Content, type PartListUnion, type Tool, - type CallableTool, } from '@google/genai'; import type { Config } from '../config/config.js'; import { MockTool } from '../test-utils/mock-tool.js'; @@ -508,7 +508,7 @@ describe('LocalAgentExecutor', () => { const mockMcpTool = { tool: vi.fn(), callTool: vi.fn(), - } as unknown as CallableTool; + } as unknown as CallableToolWithProgress; const mcpTool = new DiscoveredMCPTool( mockMcpTool, diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 3c18b3daa2e..756fecffdd5 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -6,7 +6,7 @@ import { describe, it, expect, vi } from 'vitest'; import type { Mock } from 'vitest'; -import type { CallableTool } from '@google/genai'; +import type { CallableToolWithProgress } from '../tools/mcp-client.js'; import { CoreToolScheduler } from './coreToolScheduler.js'; import { type ToolCall, @@ -1934,7 +1934,7 @@ describe('CoreToolScheduler Sequential Execution', () => { const serverName = 'test-server'; const toolName = 'test-tool'; const mcpTool = new DiscoveredMCPTool( - mockMcpTool as unknown as CallableTool, + mockMcpTool as unknown as CallableToolWithProgress, serverName, toolName, 'description', diff --git a/packages/core/src/core/prompts.test.ts b/packages/core/src/core/prompts.test.ts index 12ab97cd589..949442ea601 100644 --- a/packages/core/src/core/prompts.test.ts +++ b/packages/core/src/core/prompts.test.ts @@ -25,7 +25,7 @@ import { } from '../config/models.js'; import { ApprovalMode } from '../policy/types.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; -import type { CallableTool } from '@google/genai'; +import type { CallableToolWithProgress } from '../tools/mcp-client.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; // Mock tool names if they are dynamically generated or complex @@ -442,7 +442,7 @@ describe('Core System Prompt (prompts.ts)', () => { vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN); const readOnlyMcpTool = new DiscoveredMCPTool( - {} as CallableTool, + {} as CallableToolWithProgress, 'readonly-server', 'read_static_value', 'A read-only tool', @@ -453,7 +453,7 @@ describe('Core System Prompt (prompts.ts)', () => { ); const nonReadOnlyMcpTool = new DiscoveredMCPTool( - {} as CallableTool, + {} as CallableToolWithProgress, 'nonreadonly-server', 'non_read_static_value', 'A non-read-only tool', diff --git a/packages/core/src/scheduler/scheduler.test.ts b/packages/core/src/scheduler/scheduler.test.ts index ad2d094b4ed..92397a11735 100644 --- a/packages/core/src/scheduler/scheduler.test.ts +++ b/packages/core/src/scheduler/scheduler.test.ts @@ -39,6 +39,7 @@ import { SchedulerStateManager, type TerminalCallHandler, } from './state-manager.js'; +import { coreEvents, CoreEvent } from '../utils/events.js'; import { resolveConfirmation } from './confirmation.js'; import { checkPolicy, updatePolicy } from './policy.js'; import { ToolExecutor } from './tool-executor.js'; @@ -177,6 +178,8 @@ describe('Scheduler (Orchestrator)', () => { setOutcome: vi.fn(), cancelAllQueued: vi.fn(), clearBatch: vi.fn(), + updateProgress: vi.fn(), + flushProgressThrottle: vi.fn(), } as unknown as Mocked; // Define getters for accessors idiomatically @@ -1242,4 +1245,127 @@ describe('Scheduler (Orchestrator)', () => { expect(capturedContext!.parentCallId).toBe(parentCallId); }); }); + + describe('MCPToolProgress integration', () => { + const setupExecution = () => { + const validatingCall: ValidatingToolCall = { + status: CoreToolCallStatus.Validating, + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + Object.defineProperty(mockStateManager, 'queueLength', { + get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0), + configurable: true, + }); + Object.defineProperty(mockStateManager, 'isActive', { + get: vi.fn().mockReturnValue(false), + configurable: true, + }); + vi.mocked(mockStateManager.dequeue).mockReturnValueOnce(validatingCall); + Object.defineProperty(mockStateManager, 'firstActiveCall', { + get: vi.fn().mockReturnValue(validatingCall), + configurable: true, + }); + }; + + it('should call state.updateProgress when MCPToolProgress fires for matching callId', async () => { + setupExecution(); + mockExecutor.execute.mockImplementation(async () => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: req1.callId, + serverName: 'test-server', + toolName: 'test-tool', + progress: 50, + total: 100, + message: 'halfway', + }); + return { + status: CoreToolCallStatus.Success, + } as unknown as SuccessfulToolCall; + }); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.updateProgress).toHaveBeenCalledWith('call-1', { + progress: 50, + total: 100, + message: 'halfway', + }); + }); + + it('should ignore MCPToolProgress events for other callIds', async () => { + setupExecution(); + mockExecutor.execute.mockImplementation(async () => { + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: 'other-call-id', + serverName: 'test-server', + toolName: 'test-tool', + progress: 50, + total: 100, + }); + return { + status: CoreToolCallStatus.Success, + } as unknown as SuccessfulToolCall; + }); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.updateProgress).not.toHaveBeenCalled(); + }); + + it('should clean up progress listener after execution', async () => { + setupExecution(); + mockExecutor.execute.mockResolvedValue({ + status: CoreToolCallStatus.Success, + } as unknown as SuccessfulToolCall); + + await scheduler.schedule(req1, signal); + + // Emit after execution - should NOT trigger updateProgress + mockStateManager.updateProgress.mockClear(); + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: req1.callId, + serverName: 'test-server', + toolName: 'test-tool', + progress: 99, + total: 100, + }); + + expect(mockStateManager.updateProgress).not.toHaveBeenCalled(); + }); + + it('should call flushProgressThrottle before terminal transition', async () => { + setupExecution(); + mockExecutor.execute.mockResolvedValue({ + status: CoreToolCallStatus.Success, + } as unknown as SuccessfulToolCall); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.flushProgressThrottle).toHaveBeenCalled(); + }); + + it('should call flushProgressThrottle and clean up listener on throw path', async () => { + setupExecution(); + mockExecutor.execute.mockRejectedValue(new Error('execution failed')); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.flushProgressThrottle).toHaveBeenCalled(); + + // Emit after failure - should NOT trigger updateProgress + mockStateManager.updateProgress.mockClear(); + coreEvents.emit(CoreEvent.MCPToolProgress, { + callId: req1.callId, + serverName: 'test-server', + toolName: 'test-tool', + progress: 99, + total: 100, + }); + + expect(mockStateManager.updateProgress).not.toHaveBeenCalled(); + }); + }); }); diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index b177fe0318b..dc416e4241f 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -39,6 +39,11 @@ import { type ToolConfirmationRequest, } from '../confirmation-bus/types.js'; import { runWithToolCallContext } from '../utils/toolCallContext.js'; +import { + coreEvents, + CoreEvent, + type MCPToolProgressPayload, +} from '../utils/events.js'; interface SchedulerQueueItem { requests: ToolCallRequestInfo[]; @@ -504,51 +509,69 @@ export class Scheduler { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const activeCall = this.state.firstActiveCall as ExecutingToolCall; - const result = await runWithToolCallContext( - { - callId: activeCall.request.callId, - schedulerId: this.schedulerId, - parentCallId: this.parentCallId, - }, - () => - this.executor.execute({ - call: activeCall, - signal, - outputUpdateHandler: (id, out) => - this.state.updateStatus(id, CoreToolCallStatus.Executing, { - liveOutput: out, - }), - onUpdateToolCall: (updated) => { - if ( - updated.status === CoreToolCallStatus.Executing && - updated.pid - ) { - this.state.updateStatus(callId, CoreToolCallStatus.Executing, { - pid: updated.pid, - }); - } - }, - }), - ); + const progressHandler = (payload: MCPToolProgressPayload) => { + if (payload.callId !== callId) return; + this.state.updateProgress(callId, { + progress: payload.progress, + total: payload.total, + message: payload.message, + }); + }; + coreEvents.on(CoreEvent.MCPToolProgress, progressHandler); - if (result.status === CoreToolCallStatus.Success) { - this.state.updateStatus( - callId, - CoreToolCallStatus.Success, - result.response, - ); - } else if (result.status === CoreToolCallStatus.Cancelled) { - this.state.updateStatus( - callId, - CoreToolCallStatus.Cancelled, - 'Operation cancelled', - ); - } else { - this.state.updateStatus( - callId, - CoreToolCallStatus.Error, - result.response, + try { + const result = await runWithToolCallContext( + { + callId: activeCall.request.callId, + schedulerId: this.schedulerId, + parentCallId: this.parentCallId, + }, + () => + this.executor.execute({ + call: activeCall, + signal, + outputUpdateHandler: (id, out) => + this.state.updateStatus(id, CoreToolCallStatus.Executing, { + liveOutput: out, + }), + onUpdateToolCall: (updated) => { + if ( + updated.status === CoreToolCallStatus.Executing && + updated.pid + ) { + this.state.updateStatus(callId, CoreToolCallStatus.Executing, { + pid: updated.pid, + }); + } + }, + }), ); + + coreEvents.off(CoreEvent.MCPToolProgress, progressHandler); + this.state.flushProgressThrottle(); + + if (result.status === CoreToolCallStatus.Success) { + this.state.updateStatus( + callId, + CoreToolCallStatus.Success, + result.response, + ); + } else if (result.status === CoreToolCallStatus.Cancelled) { + this.state.updateStatus( + callId, + CoreToolCallStatus.Cancelled, + 'Operation cancelled', + ); + } else { + this.state.updateStatus( + callId, + CoreToolCallStatus.Error, + result.response, + ); + } + } finally { + this.state.flushProgressThrottle(); + coreEvents.off(CoreEvent.MCPToolProgress, progressHandler); } } diff --git a/packages/core/src/scheduler/state-manager.test.ts b/packages/core/src/scheduler/state-manager.test.ts index 758ff354c01..3f306a9c207 100644 --- a/packages/core/src/scheduler/state-manager.test.ts +++ b/packages/core/src/scheduler/state-manager.test.ts @@ -650,4 +650,123 @@ describe('SchedulerStateManager', () => { expect(snapshot[2].request.callId).toBe('3'); }); }); + + describe('updateProgress', () => { + function makeExecuting(id = 'call-1') { + const call = createValidatingCall(id); + stateManager.enqueue([call]); + stateManager.dequeue(); + stateManager.updateStatus(id, CoreToolCallStatus.Executing); + vi.mocked(mockMessageBus.publish).mockClear(); + } + + it('should update mcpProgress on an executing call', () => { + makeExecuting(); + stateManager.updateProgress('call-1', { progress: 50, total: 100 }); + + const snapshot = stateManager.getSnapshot(); + const active = snapshot[0] as ExecutingToolCall; + expect(active.mcpProgress).toEqual({ progress: 50, total: 100 }); + }); + + it('should ignore non-executing calls', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + stateManager.updateStatus( + 'call-1', + CoreToolCallStatus.Success, + createMockResponse('call-1'), + ); + vi.mocked(mockMessageBus.publish).mockClear(); + + stateManager.updateProgress('call-1', { progress: 50, total: 100 }); + expect(mockMessageBus.publish).not.toHaveBeenCalled(); + }); + + it('should ignore unknown callIds', () => { + stateManager.updateProgress('nonexistent', { progress: 50, total: 100 }); + expect(mockMessageBus.publish).not.toHaveBeenCalled(); + }); + + it('should preserve mcpProgress across liveOutput updates', () => { + makeExecuting(); + stateManager.updateProgress('call-1', { progress: 50, total: 100 }); + stateManager.updateStatus('call-1', CoreToolCallStatus.Executing, { + liveOutput: 'new output', + }); + + const active = stateManager.getSnapshot()[0] as ExecutingToolCall; + expect(active.mcpProgress).toEqual({ progress: 50, total: 100 }); + expect(active.liveOutput).toBe('new output'); + }); + + it('should preserve liveOutput across progress updates', () => { + makeExecuting(); + stateManager.updateStatus('call-1', CoreToolCallStatus.Executing, { + liveOutput: 'existing output', + }); + stateManager.updateProgress('call-1', { progress: 75, total: 100 }); + + const active = stateManager.getSnapshot()[0] as ExecutingToolCall; + expect(active.liveOutput).toBe('existing output'); + expect(active.mcpProgress).toEqual({ progress: 75, total: 100 }); + }); + }); + + describe('progress throttling', () => { + function makeExecuting(id = 'call-1') { + const call = createValidatingCall(id); + stateManager.enqueue([call]); + stateManager.dequeue(); + stateManager.updateStatus(id, CoreToolCallStatus.Executing); + vi.mocked(mockMessageBus.publish).mockClear(); + } + + it('should emit immediately on first progress update', () => { + makeExecuting(); + stateManager.updateProgress('call-1', { progress: 10, total: 100 }); + expect(mockMessageBus.publish).toHaveBeenCalledTimes(1); + }); + + it('should batch rapid progress updates with leading+trailing', () => { + vi.useFakeTimers(); + makeExecuting(); + + for (let i = 1; i <= 10; i++) { + stateManager.updateProgress('call-1', { progress: i * 10, total: 100 }); + } + + // Leading emit on first call + expect(mockMessageBus.publish).toHaveBeenCalledTimes(1); + + // Advance past throttle window + vi.advanceTimersByTime(100); + + // Trailing emit + expect(mockMessageBus.publish).toHaveBeenCalledTimes(2); + + vi.useRealTimers(); + }); + + it('should flush pending progress update via flushProgressThrottle', () => { + vi.useFakeTimers(); + makeExecuting(); + + stateManager.updateProgress('call-1', { progress: 10, total: 100 }); + stateManager.updateProgress('call-1', { progress: 20, total: 100 }); + expect(mockMessageBus.publish).toHaveBeenCalledTimes(1); + + stateManager.flushProgressThrottle(); + expect(mockMessageBus.publish).toHaveBeenCalledTimes(2); + + vi.useRealTimers(); + }); + + it('should be a no-op when flushProgressThrottle has nothing pending', () => { + makeExecuting(); + stateManager.flushProgressThrottle(); + expect(mockMessageBus.publish).not.toHaveBeenCalled(); + }); + }); }); diff --git a/packages/core/src/scheduler/state-manager.ts b/packages/core/src/scheduler/state-manager.ts index 6a473ad47cc..27a004ba507 100644 --- a/packages/core/src/scheduler/state-manager.ts +++ b/packages/core/src/scheduler/state-manager.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { Progress } from '@modelcontextprotocol/sdk/types.js'; import type { ToolCall, Status, @@ -45,6 +46,8 @@ export class SchedulerStateManager { private readonly activeCalls = new Map(); private readonly queue: ToolCall[] = []; private _completedBatch: CompletedToolCall[] = []; + private progressThrottleTimer: ReturnType | null = null; + private hasPendingProgressUpdate = false; constructor( private readonly messageBus: MessageBus, @@ -52,6 +55,38 @@ export class SchedulerStateManager { private readonly onTerminalCall?: TerminalCallHandler, ) {} + updateProgress(callId: string, progress: Progress): void { + const call = this.activeCalls.get(callId); + if (!call || call.status !== CoreToolCallStatus.Executing) return; + + const updated = this.toExecuting(call, { mcpProgress: progress }); + this.activeCalls.set(callId, updated); + + if (!this.progressThrottleTimer) { + this.emitUpdate(); + this.progressThrottleTimer = setTimeout(() => { + this.progressThrottleTimer = null; + if (this.hasPendingProgressUpdate) { + this.hasPendingProgressUpdate = false; + this.emitUpdate(); + } + }, 100); + } else { + this.hasPendingProgressUpdate = true; + } + } + + flushProgressThrottle(): void { + if (this.progressThrottleTimer) { + clearTimeout(this.progressThrottleTimer); + this.progressThrottleTimer = null; + if (this.hasPendingProgressUpdate) { + this.hasPendingProgressUpdate = false; + this.emitUpdate(); + } + } + } + addToolCalls(calls: ToolCall[]): void { this.enqueue(calls); } @@ -517,6 +552,9 @@ export class SchedulerStateManager { execData?.liveOutput ?? ('liveOutput' in call ? call.liveOutput : undefined); const pid = execData?.pid ?? ('pid' in call ? call.pid : undefined); + const mcpProgress = + execData?.mcpProgress ?? + ('mcpProgress' in call ? call.mcpProgress : undefined); return { request: call.request, @@ -527,6 +565,7 @@ export class SchedulerStateManager { invocation: call.invocation, liveOutput, pid, + mcpProgress, schedulerId: call.schedulerId, approvalMode: call.approvalMode, }; diff --git a/packages/core/src/telemetry/loggers.test.ts b/packages/core/src/telemetry/loggers.test.ts index 316cf0b33f5..bd592f3b396 100644 --- a/packages/core/src/telemetry/loggers.test.ts +++ b/packages/core/src/telemetry/loggers.test.ts @@ -103,10 +103,10 @@ import { vi, describe, beforeEach, it, expect, afterEach } from 'vitest'; import { type GeminiCLIExtension } from '../config/config.js'; import { FinishReason, - type CallableTool, type GenerateContentResponseUsageMetadata, } from '@google/genai'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; +import type { CallableToolWithProgress } from '../tools/mcp-client.js'; import * as uiTelemetry from './uiTelemetry.js'; import { makeFakeConfig } from '../test-utils/config.js'; import { ClearcutLogger } from './clearcut-logger/clearcut-logger.js'; @@ -1621,7 +1621,7 @@ describe('loggers', () => { it('should log a tool call with mcp_server_name for MCP tools', () => { const mockMcpTool = new DiscoveredMCPTool( - {} as CallableTool, + {} as CallableToolWithProgress, 'mock_mcp_server', 'mock_mcp_tool', 'tool description', diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts index fc1cfcc8906..d9a3fb08b2a 100644 --- a/packages/core/src/tools/mcp-tool.test.ts +++ b/packages/core/src/tools/mcp-tool.test.ts @@ -15,7 +15,8 @@ import { } from './mcp-tool.js'; import type { ToolResult } from './tools.js'; import { ToolConfirmationOutcome } from './tools.js'; // Added ToolConfirmationOutcome -import type { CallableTool, FunctionCall, Part } from '@google/genai'; +import type { FunctionCall, Part } from '@google/genai'; +import type { CallableToolWithProgress } from './mcp-client.js'; import type { Progress } from '@modelcontextprotocol/sdk/types.js'; import { ToolErrorType } from './tool-error.js'; import { @@ -24,12 +25,12 @@ import { } from '../test-utils/mock-message-bus.js'; import { coreEvents, CoreEvent } from '../utils/events.js'; -// Mock @google/genai mcpToTool and CallableTool -// We only need to mock the parts of CallableTool that DiscoveredMCPTool uses. +// Mock CallableToolWithProgress +// We only need to mock the parts that DiscoveredMCPTool uses. const mockCallTool = vi.fn(); const mockToolMethod = vi.fn(); -const mockCallableToolInstance: Mocked = { +const mockCallableToolInstance: Mocked = { tool: mockToolMethod as any, // Not directly used by DiscoveredMCPTool instance methods callTool: mockCallTool as any, // Add other methods if DiscoveredMCPTool starts using them @@ -952,7 +953,7 @@ describe('DiscoveredMCPTool', () => { const mockCallableToolInst = { tool: vi.fn(), callTool: progressMockCallTool, - } as unknown as Mocked; + } as unknown as Mocked; const testTool = new DiscoveredMCPTool( mockCallableToolInst, diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 6876cc528ec..45f6a902862 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -18,7 +18,7 @@ import { ToolConfirmationOutcome, type PolicyUpdateOptions, } from './tools.js'; -import type { CallableTool, FunctionCall, Part } from '@google/genai'; +import type { FunctionCall, Part } from '@google/genai'; import type { Progress } from '@modelcontextprotocol/sdk/types.js'; import { ToolErrorType } from './tool-error.js'; import type { Config } from '../config/config.js'; @@ -84,7 +84,7 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation< } constructor( - private readonly mcpTool: CallableTool, + private readonly mcpTool: CallableToolWithProgress, readonly serverName: string, readonly serverToolName: string, readonly displayName: string, @@ -221,10 +221,7 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation< // Conditionally pass progressCallback to avoid passing undefined as // a second arg (which would change the call signature for existing tests) const callPromise = progressCallback - ? (this.mcpTool as CallableToolWithProgress).callTool( - functionCalls, - progressCallback, - ) + ? this.mcpTool.callTool(functionCalls, progressCallback) : this.mcpTool.callTool(functionCalls); callPromise @@ -273,7 +270,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< ToolResult > { constructor( - private readonly mcpTool: CallableTool, + private readonly mcpTool: CallableToolWithProgress, readonly serverName: string, readonly serverToolName: string, description: string, diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index 963830200df..d12979287e7 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -14,7 +14,8 @@ import { ApprovalMode } from '../policy/types.js'; import { ToolRegistry, DiscoveredTool } from './tool-registry.js'; import { DISCOVERED_TOOL_PREFIX } from './tool-names.js'; import { DiscoveredMCPTool, MCP_QUALIFIED_NAME_SEPARATOR } from './mcp-tool.js'; -import type { FunctionDeclaration, CallableTool } from '@google/genai'; +import type { FunctionDeclaration } from '@google/genai'; +import type { CallableToolWithProgress } from './mcp-client.js'; import { mcpToTool } from '@google/genai'; import { spawn } from 'node:child_process'; @@ -106,10 +107,9 @@ vi.mock('./tool-names.js', async (importOriginal) => { }; }); -// Helper to create a mock CallableTool for specific test needs const createMockCallableTool = ( toolDeclarations: FunctionDeclaration[], -): Mocked => ({ +): Mocked => ({ tool: vi.fn().mockResolvedValue({ functionDeclarations: toolDeclarations }), callTool: vi.fn(), }); @@ -125,7 +125,7 @@ const createMCPTool = ( serverName: string, toolName: string, description: string, - mockCallable: CallableTool = {} as CallableTool, + mockCallable: CallableToolWithProgress = {} as CallableToolWithProgress, ) => new DiscoveredMCPTool( mockCallable,