From 33fec42c8d355e272a4766a50be4f720e354a1a9 Mon Sep 17 00:00:00 2001 From: Abhi Date: Sun, 7 Sep 2025 19:45:34 -0400 Subject: [PATCH 1/3] refactor - init llm utility service --- packages/core/src/config/config.test.ts | 61 ++++ packages/core/src/config/config.ts | 28 ++ .../core/src/core/llmUtilityService.test.ts | 279 ++++++++++++++++++ packages/core/src/core/llmUtilityService.ts | 152 ++++++++++ 4 files changed, 520 insertions(+) create mode 100644 packages/core/src/core/llmUtilityService.test.ts create mode 100644 packages/core/src/core/llmUtilityService.ts diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 13e051c0b92..ecdf7c7b03f 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -122,6 +122,10 @@ vi.mock('../ide/ide-client.js', () => ({ }, })); +import { LlmUtilityService } from '../core/llmUtilityService.js'; + +vi.mock('../core/llmUtilityService.js'); + describe('Server Config (config.ts)', () => { const MODEL = 'gemini-pro'; const SANDBOX: SandboxConfig = { @@ -774,3 +778,60 @@ describe('setApprovalMode with folder trust', () => { }); }); }); + +describe('LlmUtilityService Lifecycle', () => { + const MODEL = 'gemini-pro'; + const SANDBOX: SandboxConfig = { + command: 'docker', + image: 'gemini-cli-sandbox', + }; + const TARGET_DIR = '/path/to/target'; + const DEBUG_MODE = false; + const QUESTION = 'test question'; + const FULL_CONTEXT = false; + const USER_MEMORY = 'Test User Memory'; + const TELEMETRY_SETTINGS = { enabled: false }; + const EMBEDDING_MODEL = 'gemini-embedding'; + const SESSION_ID = 'test-session-id'; + const baseParams: ConfigParameters = { + cwd: '/tmp', + embeddingModel: EMBEDDING_MODEL, + sandbox: SANDBOX, + targetDir: TARGET_DIR, + debugMode: DEBUG_MODE, + question: QUESTION, + fullContext: FULL_CONTEXT, + userMemory: USER_MEMORY, + telemetry: TELEMETRY_SETTINGS, + sessionId: SESSION_ID, + model: MODEL, + usageStatisticsEnabled: false, + }; + + it('should throw an error if getLlmUtilityService is called before refreshAuth', () => { + const config = new Config(baseParams); + expect(() => config.getLlmUtilityService()).toThrow( + 'LlmUtilityService not initialized. Ensure authentication has occurred and GeminiClient is ready.', + ); + }); + + it('should successfully initialize LlmUtilityService after refreshAuth is called', async () => { + const config = new Config(baseParams); + const authType = AuthType.USE_GEMINI; + const mockContentConfig = { model: 'gemini-flash', apiKey: 'test-key' }; + + vi.mocked(createContentGeneratorConfig).mockReturnValue( + mockContentConfig, + ); + + await config.refreshAuth(authType); + + // Should not throw + const llmService = config.getLlmUtilityService(); + expect(llmService).toBeDefined(); + expect(LlmUtilityService).toHaveBeenCalledWith( + config.getContentGenerator(), + config, + ); + }); + }); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 746f1051b26..a022df8e687 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -31,6 +31,7 @@ import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js'; import { WebSearchTool } from '../tools/web-search.js'; import { GeminiClient } from '../core/client.js'; +import { LlmUtilityService } from '../core/llmUtilityService.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { GitService } from '../services/gitService.js'; import type { TelemetryTarget } from '../telemetry/index.js'; @@ -257,6 +258,7 @@ export class Config { private readonly telemetrySettings: TelemetrySettings; private readonly usageStatisticsEnabled: boolean; private geminiClient!: GeminiClient; + private llmUtilityService!: LlmUtilityService; private readonly fileFiltering: { respectGitIgnore: boolean; respectGeminiIgnore: boolean; @@ -455,6 +457,12 @@ export class Config { // Only assign to instance properties after successful initialization this.contentGeneratorConfig = newContentGeneratorConfig; + // Initialize LlmUtilityService now that the ContentGenerator is available + this.llmUtilityService = new LlmUtilityService( + this.contentGenerator, + this, + ); + // Reset the session flag since we're explicitly changing auth and using default model this.inFallbackMode = false; } @@ -463,6 +471,26 @@ export class Config { return this.contentGenerator?.userTier; } + /** + * Provides access to the LlmUtilityService for stateless LLM operations. + */ + getLlmUtilityService(): LlmUtilityService { + if (!this.llmUtilityService) { + // Handle cases where initialization might be deferred or authentication failed + if (this.contentGenerator) { + this.llmUtilityService = new LlmUtilityService( + this.getContentGenerator(), + this, + ); + } else { + throw new Error( + 'LlmUtilityService not initialized. Ensure authentication has occurred and GeminiClient is ready.', + ); + } + } + return this.llmUtilityService; + } + getSessionId(): string { return this.sessionId; } diff --git a/packages/core/src/core/llmUtilityService.test.ts b/packages/core/src/core/llmUtilityService.test.ts new file mode 100644 index 00000000000..390f5fbd1bb --- /dev/null +++ b/packages/core/src/core/llmUtilityService.test.ts @@ -0,0 +1,279 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mocked, +} from 'vitest'; + +import type { GenerateContentResponse } from '@google/genai'; +import { LlmUtilityService, type GenerateJsonOptions } from './llmUtilityService.js'; +import type { ContentGenerator } from './contentGenerator.js'; +import type { Config } from '../config/config.js'; +import { AuthType } from './contentGenerator.js'; +import { reportError } from '../utils/errorReporting.js'; +import { logMalformedJsonResponse } from '../telemetry/loggers.js'; +import { retryWithBackoff } from '../utils/retry.js'; +import { MalformedJsonResponseEvent } from '../telemetry/types.js'; +import { getErrorMessage } from '../utils/errors.js'; + + +vi.mock('../utils/errorReporting.js'); +vi.mock('../telemetry/loggers.js'); +vi.mock('../utils/errors.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + getErrorMessage: vi.fn((e) => (e instanceof Error ? e.message : String(e))), + }; +}); + +vi.mock('../utils/retry.js', () => ({ + retryWithBackoff: vi.fn(async (fn) => await fn()), +})); + +const mockGenerateContent = vi.fn(); + +const mockContentGenerator = { + generateContent: mockGenerateContent, +} as unknown as Mocked; + +const mockConfig = { + getSessionId: vi.fn().mockReturnValue('test-session-id'), + getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: AuthType.USE_GEMINI }), +} as unknown as Mocked; + +// Helper to create a mock GenerateContentResponse +const createMockResponse = (text: string): GenerateContentResponse => ({ + candidates: [{ content: { role: 'model', parts: [{ text }] }, index: 0 }], +} as GenerateContentResponse); + +describe('LlmUtilityService', () => { + let service: LlmUtilityService; + let abortController: AbortController; + let defaultOptions: GenerateJsonOptions; + + beforeEach(() => { + vi.clearAllMocks(); + // Reset the mocked implementation for getErrorMessage for accurate error message assertions + vi.mocked(getErrorMessage).mockImplementation((e) => (e instanceof Error ? e.message : String(e))); + service = new LlmUtilityService(mockContentGenerator, mockConfig); + abortController = new AbortController(); + defaultOptions = { + contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }], + schema: { type: 'object', properties: { color: { type: 'string' } } }, + model: 'test-model', + abortSignal: abortController.signal, + promptId: 'test-prompt-id', + }; + }); + + afterEach(() => { + abortController.abort(); + }); + + describe('generateJson - Success Scenarios', () => { + it('should call generateContent with correct parameters, defaults, and utilize retry mechanism', async () => { + const mockResponse = createMockResponse('{"color": "blue"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + + const result = await service.generateJson(defaultOptions); + + expect(result).toEqual({ color: 'blue' }); + + // Ensure the retry mechanism was engaged + expect(retryWithBackoff).toHaveBeenCalledTimes(1); + + // Validate the parameters passed to the underlying generator + expect(mockGenerateContent).toHaveBeenCalledTimes(1); + expect(mockGenerateContent).toHaveBeenCalledWith( + { + model: 'test-model', + contents: defaultOptions.contents, + config: { + abortSignal: defaultOptions.abortSignal, + temperature: 0, + topP: 1, + responseJsonSchema: defaultOptions.schema, + responseMimeType: 'application/json', + // Crucial: systemInstruction should NOT be in the config object if not provided + }, + }, + 'test-prompt-id', + ); + }); + + it('should respect configuration overrides', async () => { + const mockResponse = createMockResponse('{"color": "red"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + + const options: GenerateJsonOptions = { + ...defaultOptions, + config: { temperature: 0.8, topK: 10 }, + }; + + await service.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + temperature: 0.8, + topP: 1, // Default should remain if not overridden + topK: 10, + }), + }), + expect.any(String), + ); + }); + + it('should include system instructions when provided', async () => { + const mockResponse = createMockResponse('{"color": "green"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + const systemInstruction = 'You are a helpful assistant.'; + + const options: GenerateJsonOptions = { + ...defaultOptions, + systemInstruction, + }; + + await service.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + systemInstruction, + }), + }), + expect.any(String), + ); + }); + + it('should use the provided promptId', async () => { + const mockResponse = createMockResponse('{"color": "yellow"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + const customPromptId = 'custom-id-123'; + + const options: GenerateJsonOptions = { + ...defaultOptions, + promptId: customPromptId, + }; + + await service.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.any(Object), + customPromptId, + ); + }); + }); + + describe('generateJson - Response Cleaning', () => { + it('should clean JSON wrapped in markdown backticks and log telemetry', async () => { + const malformedResponse = '```json\n{"color": "purple"}\n```'; + mockGenerateContent.mockResolvedValue(createMockResponse(malformedResponse)); + + const result = await service.generateJson(defaultOptions); + + expect(result).toEqual({ color: 'purple' }); + expect(logMalformedJsonResponse).toHaveBeenCalledTimes(1); + expect(logMalformedJsonResponse).toHaveBeenCalledWith( + mockConfig, + expect.any(MalformedJsonResponseEvent) + ); + // Validate the telemetry event content + const event = vi.mocked(logMalformedJsonResponse).mock.calls[0][1] as MalformedJsonResponseEvent; + expect(event.model).toBe('test-model'); + }); + + it('should handle extra whitespace correctly without logging malformed telemetry', async () => { + const responseWithWhitespace = ' \n {"color": "orange"} \n'; + mockGenerateContent.mockResolvedValue(createMockResponse(responseWithWhitespace)); + + const result = await service.generateJson(defaultOptions); + + expect(result).toEqual({ color: 'orange' }); + expect(logMalformedJsonResponse).not.toHaveBeenCalled(); + }); + }); + + describe('generateJson - Error Handling', () => { + it('should throw and report error for empty response', async () => { + mockGenerateContent.mockResolvedValue(createMockResponse('')); + + // The final error message includes the prefix added by the service's outer catch block. + await expect(service.generateJson(defaultOptions)).rejects.toThrow( + 'Failed to generate JSON content: API returned an empty response for generateJson.' + ); + + // Verify error reporting details + expect(reportError).toHaveBeenCalledTimes(1); + expect(reportError).toHaveBeenCalledWith( + expect.any(Error), + 'Error in generateJson: API returned an empty response.', + defaultOptions.contents, + 'generateJson-empty-response' + ); + }); + + it('should throw and report error for invalid JSON syntax', async () => { + const invalidJson = '{"color": "blue"'; // missing closing brace + mockGenerateContent.mockResolvedValue(createMockResponse(invalidJson)); + + await expect(service.generateJson(defaultOptions)).rejects.toThrow( + /^Failed to generate JSON content: Failed to parse API response as JSON:/ + ); + + expect(reportError).toHaveBeenCalledTimes(1); + expect(reportError).toHaveBeenCalledWith( + expect.any(Error), + 'Failed to parse JSON response from generateJson.', + expect.objectContaining({ responseTextFailedToParse: invalidJson }), + 'generateJson-parse' + ); + }); + + it('should throw and report generic API errors', async () => { + const apiError = new Error('Service Unavailable (503)'); + // Simulate the generator failing + mockGenerateContent.mockRejectedValue(apiError); + + await expect(service.generateJson(defaultOptions)).rejects.toThrow( + 'Failed to generate JSON content: Service Unavailable (503)' + ); + + // Verify generic error reporting + expect(reportError).toHaveBeenCalledTimes(1); + expect(reportError).toHaveBeenCalledWith( + apiError, + 'Error generating JSON content via API.', + defaultOptions.contents, + 'generateJson-api' + ); + }); + + it('should throw immediately without reporting if aborted', async () => { + const abortError = new DOMException('Aborted', 'AbortError'); + + // Simulate abortion happening during the API call + mockGenerateContent.mockImplementation(() => { + abortController.abort(); // Ensure the signal is aborted when the service checks + throw abortError; + }); + + const options = { ...defaultOptions, abortSignal: abortController.signal }; + + await expect(service.generateJson(options)).rejects.toThrow(abortError); + + // Crucially, it should not report a cancellation as an application error + expect(reportError).not.toHaveBeenCalled(); + }); + }); +}); \ No newline at end of file diff --git a/packages/core/src/core/llmUtilityService.ts b/packages/core/src/core/llmUtilityService.ts new file mode 100644 index 00000000000..8dcb50d09e1 --- /dev/null +++ b/packages/core/src/core/llmUtilityService.ts @@ -0,0 +1,152 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + Content, + GenerateContentConfig, + Part, +} from '@google/genai'; +import type { Config } from '../config/config.js'; +import type { ContentGenerator } from './contentGenerator.js'; +import { getResponseText } from '../utils/partUtils.js'; +import { reportError } from '../utils/errorReporting.js'; +import { getErrorMessage } from '../utils/errors.js'; +import { logMalformedJsonResponse } from '../telemetry/loggers.js'; +import { MalformedJsonResponseEvent } from '../telemetry/types.js'; +import { retryWithBackoff } from '../utils/retry.js'; + +/** + * Options for the generateJson utility function. + */ +export interface GenerateJsonOptions { + /** The input prompt or history. */ + contents: Content[]; + /** The required JSON schema for the output. */ + schema: Record; + /** The specific model to use for this task. */ + model: string; + /** + * Task-specific system instructions. + * If omitted, no system instruction is sent. + */ + systemInstruction?: string | Part | Part[] | Content; + /** + * Overrides for generation configuration (e.g., temperature). + */ + config?: Omit; + /** Signal for cancellation. */ + abortSignal: AbortSignal; + /** + * A unique ID for the prompt, used for logging/telemetry correlation. + */ + promptId: string; +} + +/** + * A service dedicated to stateless, utility-focused LLM calls. + */ +export class LlmUtilityService { + // Default configuration for utility tasks + private readonly defaultUtilityConfig: GenerateContentConfig = { + temperature: 0, + topP: 1, + }; + + constructor( + private readonly contentGenerator: ContentGenerator, + private readonly config: Config, + ) {} + + async generateJson(options: GenerateJsonOptions): Promise> { + const { contents, schema, model, abortSignal, systemInstruction, promptId } = options; + + const requestConfig: GenerateContentConfig = { + abortSignal, + ...this.defaultUtilityConfig, + ...options.config, + ...(systemInstruction && { systemInstruction }), + responseJsonSchema: schema, + responseMimeType: 'application/json', + }; + + try { + const apiCall = () => this.contentGenerator.generateContent( + { + model, + config: requestConfig, + contents, + }, + promptId, + ); + + const result = await retryWithBackoff(apiCall); + + let text = getResponseText(result)?.trim(); + if (!text) { + const error = new Error('API returned an empty response for generateJson.'); + await reportError( + error, + 'Error in generateJson: API returned an empty response.', + contents, + 'generateJson-empty-response', + ); + throw error; + } + + text = this.cleanJsonResponse(text, model); + + try { + return JSON.parse(text); + } catch (parseError) { + const error = new Error(`Failed to parse API response as JSON: ${getErrorMessage(parseError)}`); + await reportError( + parseError, + 'Failed to parse JSON response from generateJson.', + { + responseTextFailedToParse: text, + originalRequestContents: contents, + }, + 'generateJson-parse', + ); + throw error; + } + + } catch (error) { + if (abortSignal.aborted) { + throw error; + } + + if (error instanceof Error && ( + error.message === 'API returned an empty response for generateJson.' || + error.message.startsWith('Failed to parse API response as JSON:') + )) { + // We perform this check so that we don't report these again. + } else { + await reportError( + error, + 'Error generating JSON content via API.', + contents, + 'generateJson-api', + ); + } + + throw new Error(`Failed to generate JSON content: ${getErrorMessage(error)}`); + } + } + + private cleanJsonResponse(text: string, model: string): string { + const prefix = '```json'; + const suffix = '```'; + if (text.startsWith(prefix) && text.endsWith(suffix)) { + logMalformedJsonResponse( + this.config, + new MalformedJsonResponseEvent(model), + ); + return text.substring(prefix.length, text.length - suffix.length).trim(); + } + return text; + } +} From 44ee4b06915ec829a300cb0bad616c9f49adaff0 Mon Sep 17 00:00:00 2001 From: Abhi Date: Sun, 7 Sep 2025 23:16:26 -0400 Subject: [PATCH 2/3] Refactor LLM JSON Generation calls --- packages/cli/src/nonInteractiveCli.ts | 195 ++++++++--------- packages/cli/src/ui/hooks/useGeminiStream.ts | 110 +++++----- packages/core/src/config/config.test.ts | 98 +++++---- packages/core/src/config/config.ts | 5 +- .../core/src/core/llmUtilityService.test.ts | 145 +++++++------ packages/core/src/core/llmUtilityService.ts | 101 +++++---- packages/core/src/index.ts | 1 + packages/core/src/tools/smart-edit.test.ts | 7 + packages/core/src/tools/smart-edit.ts | 2 +- .../core/src/utils/llm-edit-fixer.test.ts | 204 ++++++++++++++++++ packages/core/src/utils/llm-edit-fixer.ts | 36 ++-- packages/core/src/utils/promptIdContext.ts | 13 ++ 12 files changed, 593 insertions(+), 324 deletions(-) create mode 100644 packages/core/src/utils/llm-edit-fixer.test.ts create mode 100644 packages/core/src/utils/promptIdContext.ts diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 73e8ae23711..4adcd93f305 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -13,6 +13,7 @@ import { parseAndFormatApiError, FatalInputError, FatalTurnLimitedError, + promptIdContext, } from '@google/gemini-cli-core'; import type { Content, Part } from '@google/genai'; @@ -24,115 +25,117 @@ export async function runNonInteractive( input: string, prompt_id: string, ): Promise { - const consolePatcher = new ConsolePatcher({ - stderr: true, - debugMode: config.getDebugMode(), - }); - - try { - consolePatcher.patch(); - // Handle EPIPE errors when the output is piped to a command that closes early. - process.stdout.on('error', (err: NodeJS.ErrnoException) => { - if (err.code === 'EPIPE') { - // Exit gracefully if the pipe is closed. - process.exit(0); - } - }); - - const geminiClient = config.getGeminiClient(); - - const abortController = new AbortController(); - - const { processedQuery, shouldProceed } = await handleAtCommand({ - query: input, - config, - addItem: (_item, _timestamp) => 0, - onDebugMessage: () => {}, - messageId: Date.now(), - signal: abortController.signal, + return promptIdContext.run({ promptId: prompt_id }, async () => { + const consolePatcher = new ConsolePatcher({ + stderr: true, + debugMode: config.getDebugMode(), }); - if (!shouldProceed || !processedQuery) { - // An error occurred during @include processing (e.g., file not found). - // The error message is already logged by handleAtCommand. - throw new FatalInputError( - 'Exiting due to an error processing the @ command.', - ); - } - - let currentMessages: Content[] = [ - { role: 'user', parts: processedQuery as Part[] }, - ]; - - let turnCount = 0; - while (true) { - turnCount++; - if ( - config.getMaxSessionTurns() >= 0 && - turnCount > config.getMaxSessionTurns() - ) { - throw new FatalTurnLimitedError( - 'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', + try { + consolePatcher.patch(); + // Handle EPIPE errors when the output is piped to a command that closes early. + process.stdout.on('error', (err: NodeJS.ErrnoException) => { + if (err.code === 'EPIPE') { + // Exit gracefully if the pipe is closed. + process.exit(0); + } + }); + + const geminiClient = config.getGeminiClient(); + + const abortController = new AbortController(); + + const { processedQuery, shouldProceed } = await handleAtCommand({ + query: input, + config, + addItem: (_item, _timestamp) => 0, + onDebugMessage: () => {}, + messageId: Date.now(), + signal: abortController.signal, + }); + + if (!shouldProceed || !processedQuery) { + // An error occurred during @include processing (e.g., file not found). + // The error message is already logged by handleAtCommand. + throw new FatalInputError( + 'Exiting due to an error processing the @ command.', ); } - const toolCallRequests: ToolCallRequestInfo[] = []; - const responseStream = geminiClient.sendMessageStream( - currentMessages[0]?.parts || [], - abortController.signal, - prompt_id, - ); - - for await (const event of responseStream) { - if (abortController.signal.aborted) { - console.error('Operation cancelled.'); - return; + let currentMessages: Content[] = [ + { role: 'user', parts: processedQuery as Part[] }, + ]; + + let turnCount = 0; + while (true) { + turnCount++; + if ( + config.getMaxSessionTurns() >= 0 && + turnCount > config.getMaxSessionTurns() + ) { + throw new FatalTurnLimitedError( + 'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', + ); } + const toolCallRequests: ToolCallRequestInfo[] = []; - if (event.type === GeminiEventType.Content) { - process.stdout.write(event.value); - } else if (event.type === GeminiEventType.ToolCallRequest) { - toolCallRequests.push(event.value); - } - } + const responseStream = geminiClient.sendMessageStream( + currentMessages[0]?.parts || [], + abortController.signal, + prompt_id, + ); - if (toolCallRequests.length > 0) { - const toolResponseParts: Part[] = []; - for (const requestInfo of toolCallRequests) { - const toolResponse = await executeToolCall( - config, - requestInfo, - abortController.signal, - ); + for await (const event of responseStream) { + if (abortController.signal.aborted) { + console.error('Operation cancelled.'); + return; + } - if (toolResponse.error) { - console.error( - `Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, - ); + if (event.type === GeminiEventType.Content) { + process.stdout.write(event.value); + } else if (event.type === GeminiEventType.ToolCallRequest) { + toolCallRequests.push(event.value); } + } - if (toolResponse.responseParts) { - toolResponseParts.push(...toolResponse.responseParts); + if (toolCallRequests.length > 0) { + const toolResponseParts: Part[] = []; + for (const requestInfo of toolCallRequests) { + const toolResponse = await executeToolCall( + config, + requestInfo, + abortController.signal, + ); + + if (toolResponse.error) { + console.error( + `Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, + ); + } + + if (toolResponse.responseParts) { + toolResponseParts.push(...toolResponse.responseParts); + } } + currentMessages = [{ role: 'user', parts: toolResponseParts }]; + } else { + process.stdout.write('\n'); // Ensure a final newline + return; } - currentMessages = [{ role: 'user', parts: toolResponseParts }]; - } else { - process.stdout.write('\n'); // Ensure a final newline - return; + } + } catch (error) { + console.error( + parseAndFormatApiError( + error, + config.getContentGeneratorConfig()?.authType, + ), + ); + throw error; + } finally { + consolePatcher.cleanup(); + if (isTelemetrySdkInitialized()) { + await shutdownTelemetry(config); } } - } catch (error) { - console.error( - parseAndFormatApiError( - error, - config.getContentGeneratorConfig()?.authType, - ), - ); - throw error; - } finally { - consolePatcher.cleanup(); - if (isTelemetrySdkInitialized()) { - await shutdownTelemetry(config); - } - } + }); } diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index edaf33d5472..618329b4a40 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -33,6 +33,7 @@ import { parseAndFormatApiError, getCodeAssistServer, UserTierId, + promptIdContext, } from '@google/gemini-cli-core'; import { type Part, type PartListUnion, FinishReason } from '@google/genai'; import type { @@ -705,71 +706,72 @@ export const useGeminiStream = ( if (!prompt_id) { prompt_id = config.getSessionId() + '########' + getPromptCount(); } - - const { queryToSend, shouldProceed } = await prepareQueryForGemini( - query, - userMessageTimestamp, - abortSignal, - prompt_id!, - ); - - if (!shouldProceed || queryToSend === null) { - return; - } - - if (!options?.isContinuation) { - startNewPrompt(); - setThought(null); // Reset thought when starting a new prompt - } - - setIsResponding(true); - setInitError(null); - - try { - const stream = geminiClient.sendMessageStream( - queryToSend, - abortSignal, - prompt_id!, - ); - const processingStatus = await processGeminiStreamEvents( - stream, + return promptIdContext.run({ promptId: prompt_id! }, async () => { + const { queryToSend, shouldProceed } = await prepareQueryForGemini( + query, userMessageTimestamp, abortSignal, + prompt_id!, ); - if (processingStatus === StreamProcessingStatus.UserCancelled) { + if (!shouldProceed || queryToSend === null) { return; } - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - setPendingHistoryItem(null); + if (!options?.isContinuation) { + startNewPrompt(); + setThought(null); // Reset thought when starting a new prompt } - if (loopDetectedRef.current) { - loopDetectedRef.current = false; - handleLoopDetectedEvent(); - } - } catch (error: unknown) { - if (error instanceof UnauthorizedError) { - onAuthError('Session expired or is unauthorized.'); - } else if (!isNodeError(error) || error.name !== 'AbortError') { - addItem( - { - type: MessageType.ERROR, - text: parseAndFormatApiError( - getErrorMessage(error) || 'Unknown error', - config.getContentGeneratorConfig()?.authType, - undefined, - config.getModel(), - DEFAULT_GEMINI_FLASH_MODEL, - ), - }, + + setIsResponding(true); + setInitError(null); + + try { + const stream = geminiClient.sendMessageStream( + queryToSend, + abortSignal, + prompt_id!, + ); + const processingStatus = await processGeminiStreamEvents( + stream, userMessageTimestamp, + abortSignal, ); + + if (processingStatus === StreamProcessingStatus.UserCancelled) { + return; + } + + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + setPendingHistoryItem(null); + } + if (loopDetectedRef.current) { + loopDetectedRef.current = false; + handleLoopDetectedEvent(); + } + } catch (error: unknown) { + if (error instanceof UnauthorizedError) { + onAuthError('Session expired or is unauthorized.'); + } else if (!isNodeError(error) || error.name !== 'AbortError') { + addItem( + { + type: MessageType.ERROR, + text: parseAndFormatApiError( + getErrorMessage(error) || 'Unknown error', + config.getContentGeneratorConfig()?.authType, + undefined, + config.getModel(), + DEFAULT_GEMINI_FLASH_MODEL, + ), + }, + userMessageTimestamp, + ); + } + } finally { + setIsResponding(false); } - } finally { - setIsResponding(false); - } + }); }, [ streamingState, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index ecdf7c7b03f..e5e10d98a39 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -780,58 +780,56 @@ describe('setApprovalMode with folder trust', () => { }); describe('LlmUtilityService Lifecycle', () => { - const MODEL = 'gemini-pro'; - const SANDBOX: SandboxConfig = { - command: 'docker', - image: 'gemini-cli-sandbox', - }; - const TARGET_DIR = '/path/to/target'; - const DEBUG_MODE = false; - const QUESTION = 'test question'; - const FULL_CONTEXT = false; - const USER_MEMORY = 'Test User Memory'; - const TELEMETRY_SETTINGS = { enabled: false }; - const EMBEDDING_MODEL = 'gemini-embedding'; - const SESSION_ID = 'test-session-id'; - const baseParams: ConfigParameters = { - cwd: '/tmp', - embeddingModel: EMBEDDING_MODEL, - sandbox: SANDBOX, - targetDir: TARGET_DIR, - debugMode: DEBUG_MODE, - question: QUESTION, - fullContext: FULL_CONTEXT, - userMemory: USER_MEMORY, - telemetry: TELEMETRY_SETTINGS, - sessionId: SESSION_ID, - model: MODEL, - usageStatisticsEnabled: false, - }; - - it('should throw an error if getLlmUtilityService is called before refreshAuth', () => { - const config = new Config(baseParams); - expect(() => config.getLlmUtilityService()).toThrow( - 'LlmUtilityService not initialized. Ensure authentication has occurred and GeminiClient is ready.', - ); - }); + const MODEL = 'gemini-pro'; + const SANDBOX: SandboxConfig = { + command: 'docker', + image: 'gemini-cli-sandbox', + }; + const TARGET_DIR = '/path/to/target'; + const DEBUG_MODE = false; + const QUESTION = 'test question'; + const FULL_CONTEXT = false; + const USER_MEMORY = 'Test User Memory'; + const TELEMETRY_SETTINGS = { enabled: false }; + const EMBEDDING_MODEL = 'gemini-embedding'; + const SESSION_ID = 'test-session-id'; + const baseParams: ConfigParameters = { + cwd: '/tmp', + embeddingModel: EMBEDDING_MODEL, + sandbox: SANDBOX, + targetDir: TARGET_DIR, + debugMode: DEBUG_MODE, + question: QUESTION, + fullContext: FULL_CONTEXT, + userMemory: USER_MEMORY, + telemetry: TELEMETRY_SETTINGS, + sessionId: SESSION_ID, + model: MODEL, + usageStatisticsEnabled: false, + }; - it('should successfully initialize LlmUtilityService after refreshAuth is called', async () => { - const config = new Config(baseParams); - const authType = AuthType.USE_GEMINI; - const mockContentConfig = { model: 'gemini-flash', apiKey: 'test-key' }; + it('should throw an error if getLlmUtilityService is called before refreshAuth', () => { + const config = new Config(baseParams); + expect(() => config.getLlmUtilityService()).toThrow( + 'LlmUtilityService not initialized. Ensure authentication has occurred and GeminiClient is ready.', + ); + }); - vi.mocked(createContentGeneratorConfig).mockReturnValue( - mockContentConfig, - ); + it('should successfully initialize LlmUtilityService after refreshAuth is called', async () => { + const config = new Config(baseParams); + const authType = AuthType.USE_GEMINI; + const mockContentConfig = { model: 'gemini-flash', apiKey: 'test-key' }; - await config.refreshAuth(authType); + vi.mocked(createContentGeneratorConfig).mockReturnValue(mockContentConfig); - // Should not throw - const llmService = config.getLlmUtilityService(); - expect(llmService).toBeDefined(); - expect(LlmUtilityService).toHaveBeenCalledWith( - config.getContentGenerator(), - config, - ); - }); + await config.refreshAuth(authType); + + // Should not throw + const llmService = config.getLlmUtilityService(); + expect(llmService).toBeDefined(); + expect(LlmUtilityService).toHaveBeenCalledWith( + config.getContentGenerator(), + config, + ); }); +}); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index a022df8e687..a38dd9d18bf 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -458,10 +458,7 @@ export class Config { this.contentGeneratorConfig = newContentGeneratorConfig; // Initialize LlmUtilityService now that the ContentGenerator is available - this.llmUtilityService = new LlmUtilityService( - this.contentGenerator, - this, - ); + this.llmUtilityService = new LlmUtilityService(this.contentGenerator, this); // Reset the session flag since we're explicitly changing auth and using default model this.inFallbackMode = false; diff --git a/packages/core/src/core/llmUtilityService.test.ts b/packages/core/src/core/llmUtilityService.test.ts index 390f5fbd1bb..03f1109c054 100644 --- a/packages/core/src/core/llmUtilityService.test.ts +++ b/packages/core/src/core/llmUtilityService.test.ts @@ -15,7 +15,10 @@ import { } from 'vitest'; import type { GenerateContentResponse } from '@google/genai'; -import { LlmUtilityService, type GenerateJsonOptions } from './llmUtilityService.js'; +import { + LlmUtilityService, + type GenerateJsonOptions, +} from './llmUtilityService.js'; import type { ContentGenerator } from './contentGenerator.js'; import type { Config } from '../config/config.js'; import { AuthType } from './contentGenerator.js'; @@ -25,15 +28,14 @@ import { retryWithBackoff } from '../utils/retry.js'; import { MalformedJsonResponseEvent } from '../telemetry/types.js'; import { getErrorMessage } from '../utils/errors.js'; - vi.mock('../utils/errorReporting.js'); vi.mock('../telemetry/loggers.js'); vi.mock('../utils/errors.js', async (importOriginal) => { - const actual = await importOriginal(); - return { - ...actual, - getErrorMessage: vi.fn((e) => (e instanceof Error ? e.message : String(e))), - }; + const actual = await importOriginal(); + return { + ...actual, + getErrorMessage: vi.fn((e) => (e instanceof Error ? e.message : String(e))), + }; }); vi.mock('../utils/retry.js', () => ({ @@ -48,13 +50,16 @@ const mockContentGenerator = { const mockConfig = { getSessionId: vi.fn().mockReturnValue('test-session-id'), - getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: AuthType.USE_GEMINI }), + getContentGeneratorConfig: vi + .fn() + .mockReturnValue({ authType: AuthType.USE_GEMINI }), } as unknown as Mocked; // Helper to create a mock GenerateContentResponse -const createMockResponse = (text: string): GenerateContentResponse => ({ - candidates: [{ content: { role: 'model', parts: [{ text }] }, index: 0 }], -} as GenerateContentResponse); +const createMockResponse = (text: string): GenerateContentResponse => + ({ + candidates: [{ content: { role: 'model', parts: [{ text }] }, index: 0 }], + }) as GenerateContentResponse; describe('LlmUtilityService', () => { let service: LlmUtilityService; @@ -64,7 +69,9 @@ describe('LlmUtilityService', () => { beforeEach(() => { vi.clearAllMocks(); // Reset the mocked implementation for getErrorMessage for accurate error message assertions - vi.mocked(getErrorMessage).mockImplementation((e) => (e instanceof Error ? e.message : String(e))); + vi.mocked(getErrorMessage).mockImplementation((e) => + e instanceof Error ? e.message : String(e), + ); service = new LlmUtilityService(mockContentGenerator, mockConfig); abortController = new AbortController(); defaultOptions = { @@ -91,7 +98,7 @@ describe('LlmUtilityService', () => { // Ensure the retry mechanism was engaged expect(retryWithBackoff).toHaveBeenCalledTimes(1); - + // Validate the parameters passed to the underlying generator expect(mockGenerateContent).toHaveBeenCalledTimes(1); expect(mockGenerateContent).toHaveBeenCalledWith( @@ -100,8 +107,8 @@ describe('LlmUtilityService', () => { contents: defaultOptions.contents, config: { abortSignal: defaultOptions.abortSignal, - temperature: 0, - topP: 1, + temperature: 0, + topP: 1, responseJsonSchema: defaultOptions.schema, responseMimeType: 'application/json', // Crucial: systemInstruction should NOT be in the config object if not provided @@ -157,51 +164,56 @@ describe('LlmUtilityService', () => { }); it('should use the provided promptId', async () => { - const mockResponse = createMockResponse('{"color": "yellow"}'); - mockGenerateContent.mockResolvedValue(mockResponse); - const customPromptId = 'custom-id-123'; - - const options: GenerateJsonOptions = { - ...defaultOptions, - promptId: customPromptId, - }; - - await service.generateJson(options); - - expect(mockGenerateContent).toHaveBeenCalledWith( - expect.any(Object), - customPromptId, - ); - }); + const mockResponse = createMockResponse('{"color": "yellow"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + const customPromptId = 'custom-id-123'; + + const options: GenerateJsonOptions = { + ...defaultOptions, + promptId: customPromptId, + }; + + await service.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.any(Object), + customPromptId, + ); + }); }); describe('generateJson - Response Cleaning', () => { it('should clean JSON wrapped in markdown backticks and log telemetry', async () => { const malformedResponse = '```json\n{"color": "purple"}\n```'; - mockGenerateContent.mockResolvedValue(createMockResponse(malformedResponse)); + mockGenerateContent.mockResolvedValue( + createMockResponse(malformedResponse), + ); const result = await service.generateJson(defaultOptions); expect(result).toEqual({ color: 'purple' }); expect(logMalformedJsonResponse).toHaveBeenCalledTimes(1); expect(logMalformedJsonResponse).toHaveBeenCalledWith( - mockConfig, - expect.any(MalformedJsonResponseEvent) + mockConfig, + expect.any(MalformedJsonResponseEvent), ); // Validate the telemetry event content - const event = vi.mocked(logMalformedJsonResponse).mock.calls[0][1] as MalformedJsonResponseEvent; + const event = vi.mocked(logMalformedJsonResponse).mock + .calls[0][1] as MalformedJsonResponseEvent; expect(event.model).toBe('test-model'); }); it('should handle extra whitespace correctly without logging malformed telemetry', async () => { - const responseWithWhitespace = ' \n {"color": "orange"} \n'; - mockGenerateContent.mockResolvedValue(createMockResponse(responseWithWhitespace)); - - const result = await service.generateJson(defaultOptions); - - expect(result).toEqual({ color: 'orange' }); - expect(logMalformedJsonResponse).not.toHaveBeenCalled(); - }); + const responseWithWhitespace = ' \n {"color": "orange"} \n'; + mockGenerateContent.mockResolvedValue( + createMockResponse(responseWithWhitespace), + ); + + const result = await service.generateJson(defaultOptions); + + expect(result).toEqual({ color: 'orange' }); + expect(logMalformedJsonResponse).not.toHaveBeenCalled(); + }); }); describe('generateJson - Error Handling', () => { @@ -210,7 +222,7 @@ describe('LlmUtilityService', () => { // The final error message includes the prefix added by the service's outer catch block. await expect(service.generateJson(defaultOptions)).rejects.toThrow( - 'Failed to generate JSON content: API returned an empty response for generateJson.' + 'Failed to generate JSON content: API returned an empty response for generateJson.', ); // Verify error reporting details @@ -219,7 +231,7 @@ describe('LlmUtilityService', () => { expect.any(Error), 'Error in generateJson: API returned an empty response.', defaultOptions.contents, - 'generateJson-empty-response' + 'generateJson-empty-response', ); }); @@ -228,15 +240,15 @@ describe('LlmUtilityService', () => { mockGenerateContent.mockResolvedValue(createMockResponse(invalidJson)); await expect(service.generateJson(defaultOptions)).rejects.toThrow( - /^Failed to generate JSON content: Failed to parse API response as JSON:/ + /^Failed to generate JSON content: Failed to parse API response as JSON:/, ); expect(reportError).toHaveBeenCalledTimes(1); expect(reportError).toHaveBeenCalledWith( - expect.any(Error), + expect.any(Error), 'Failed to parse JSON response from generateJson.', expect.objectContaining({ responseTextFailedToParse: invalidJson }), - 'generateJson-parse' + 'generateJson-parse', ); }); @@ -246,7 +258,7 @@ describe('LlmUtilityService', () => { mockGenerateContent.mockRejectedValue(apiError); await expect(service.generateJson(defaultOptions)).rejects.toThrow( - 'Failed to generate JSON content: Service Unavailable (503)' + 'Failed to generate JSON content: Service Unavailable (503)', ); // Verify generic error reporting @@ -255,25 +267,28 @@ describe('LlmUtilityService', () => { apiError, 'Error generating JSON content via API.', defaultOptions.contents, - 'generateJson-api' + 'generateJson-api', ); }); it('should throw immediately without reporting if aborted', async () => { - const abortError = new DOMException('Aborted', 'AbortError'); - - // Simulate abortion happening during the API call - mockGenerateContent.mockImplementation(() => { - abortController.abort(); // Ensure the signal is aborted when the service checks - throw abortError; - }); - - const options = { ...defaultOptions, abortSignal: abortController.signal }; - - await expect(service.generateJson(options)).rejects.toThrow(abortError); - - // Crucially, it should not report a cancellation as an application error - expect(reportError).not.toHaveBeenCalled(); + const abortError = new DOMException('Aborted', 'AbortError'); + + // Simulate abortion happening during the API call + mockGenerateContent.mockImplementation(() => { + abortController.abort(); // Ensure the signal is aborted when the service checks + throw abortError; + }); + + const options = { + ...defaultOptions, + abortSignal: abortController.signal, + }; + + await expect(service.generateJson(options)).rejects.toThrow(abortError); + + // Crucially, it should not report a cancellation as an application error + expect(reportError).not.toHaveBeenCalled(); }); }); -}); \ No newline at end of file +}); diff --git a/packages/core/src/core/llmUtilityService.ts b/packages/core/src/core/llmUtilityService.ts index 8dcb50d09e1..96fdeb5b55c 100644 --- a/packages/core/src/core/llmUtilityService.ts +++ b/packages/core/src/core/llmUtilityService.ts @@ -4,11 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { - Content, - GenerateContentConfig, - Part, -} from '@google/genai'; +import type { Content, GenerateContentConfig, Part } from '@google/genai'; import type { Config } from '../config/config.js'; import type { ContentGenerator } from './contentGenerator.js'; import { getResponseText } from '../utils/partUtils.js'; @@ -36,10 +32,17 @@ export interface GenerateJsonOptions { /** * Overrides for generation configuration (e.g., temperature). */ - config?: Omit; + config?: Omit< + GenerateContentConfig, + | 'systemInstruction' + | 'responseJsonSchema' + | 'responseMimeType' + | 'tools' + | 'abortSignal' + >; /** Signal for cancellation. */ abortSignal: AbortSignal; - /** + /** * A unique ID for the prompt, used for logging/telemetry correlation. */ promptId: string; @@ -60,8 +63,17 @@ export class LlmUtilityService { private readonly config: Config, ) {} - async generateJson(options: GenerateJsonOptions): Promise> { - const { contents, schema, model, abortSignal, systemInstruction, promptId } = options; + async generateJson( + options: GenerateJsonOptions, + ): Promise> { + const { + contents, + schema, + model, + abortSignal, + systemInstruction, + promptId, + } = options; const requestConfig: GenerateContentConfig = { abortSignal, @@ -73,20 +85,23 @@ export class LlmUtilityService { }; try { - const apiCall = () => this.contentGenerator.generateContent( - { - model, - config: requestConfig, - contents, - }, - promptId, - ); + const apiCall = () => + this.contentGenerator.generateContent( + { + model, + config: requestConfig, + contents, + }, + promptId, + ); const result = await retryWithBackoff(apiCall); let text = getResponseText(result)?.trim(); if (!text) { - const error = new Error('API returned an empty response for generateJson.'); + const error = new Error( + 'API returned an empty response for generateJson.', + ); await reportError( error, 'Error in generateJson: API returned an empty response.', @@ -101,7 +116,9 @@ export class LlmUtilityService { try { return JSON.parse(text); } catch (parseError) { - const error = new Error(`Failed to parse API response as JSON: ${getErrorMessage(parseError)}`); + const error = new Error( + `Failed to parse API response as JSON: ${getErrorMessage(parseError)}`, + ); await reportError( parseError, 'Failed to parse JSON response from generateJson.', @@ -113,40 +130,42 @@ export class LlmUtilityService { ); throw error; } - } catch (error) { - if (abortSignal.aborted) { + if (abortSignal.aborted) { throw error; } - if (error instanceof Error && ( - error.message === 'API returned an empty response for generateJson.' || - error.message.startsWith('Failed to parse API response as JSON:') - )) { - // We perform this check so that we don't report these again. + if ( + error instanceof Error && + (error.message === 'API returned an empty response for generateJson.' || + error.message.startsWith('Failed to parse API response as JSON:')) + ) { + // We perform this check so that we don't report these again. } else { await reportError( - error, - 'Error generating JSON content via API.', - contents, - 'generateJson-api', + error, + 'Error generating JSON content via API.', + contents, + 'generateJson-api', ); } - throw new Error(`Failed to generate JSON content: ${getErrorMessage(error)}`); + throw new Error( + `Failed to generate JSON content: ${getErrorMessage(error)}`, + ); } } private cleanJsonResponse(text: string, model: string): string { - const prefix = '```json'; - const suffix = '```'; - if (text.startsWith(prefix) && text.endsWith(suffix)) { - logMalformedJsonResponse( - this.config, - new MalformedJsonResponseEvent(model), - ); - return text.substring(prefix.length, text.length - suffix.length).trim(); - } - return text; + const prefix = '```json'; + const suffix = '```'; + if (text.startsWith(prefix) && text.endsWith(suffix)) { + logMalformedJsonResponse( + this.config, + new MalformedJsonResponseEvent(model), + ); + return text.substring(prefix.length, text.length - suffix.length).trim(); + } + return text; } } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 047e43a5298..aa49b0df304 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -50,6 +50,7 @@ export * from './utils/workspaceContext.js'; export * from './utils/ignorePatterns.js'; export * from './utils/partUtils.js'; export * from './utils/ide-trust.js'; +export * from './utils/promptIdContext.js'; // Export services export * from './services/fileDiscoveryService.js'; diff --git a/packages/core/src/tools/smart-edit.test.ts b/packages/core/src/tools/smart-edit.test.ts index 132d9933067..6d1fb3d93f5 100644 --- a/packages/core/src/tools/smart-edit.test.ts +++ b/packages/core/src/tools/smart-edit.test.ts @@ -60,6 +60,7 @@ import { ApprovalMode, type Config } from '../config/config.js'; import { type Content, type Part, type SchemaUnion } from '@google/genai'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; +import type { LlmUtilityService } from '../core/llmUtilityService.js'; describe('SmartEditTool', () => { let tool: SmartEditTool; @@ -67,6 +68,7 @@ describe('SmartEditTool', () => { let rootDir: string; let mockConfig: Config; let geminiClient: any; + let llmUtilityService: LlmUtilityService; beforeEach(() => { vi.restoreAllMocks(); @@ -78,8 +80,13 @@ describe('SmartEditTool', () => { generateJson: mockGenerateJson, }; + llmUtilityService = { + generateJson: mockGenerateJson, + } as unknown as LlmUtilityService; + mockConfig = { getGeminiClient: vi.fn().mockReturnValue(geminiClient), + getLlmUtilityService: vi.fn().mockReturnValue(llmUtilityService), getTargetDir: () => rootDir, getApprovalMode: vi.fn(), setApprovalMode: vi.fn(), diff --git a/packages/core/src/tools/smart-edit.ts b/packages/core/src/tools/smart-edit.ts index 3647b73246b..920222ba452 100644 --- a/packages/core/src/tools/smart-edit.ts +++ b/packages/core/src/tools/smart-edit.ts @@ -310,7 +310,7 @@ class EditToolInvocation implements ToolInvocation { params.new_string, initialError.raw, currentContent, - this.config.getGeminiClient(), + this.config.getLlmUtilityService(), abortSignal, ); diff --git a/packages/core/src/utils/llm-edit-fixer.test.ts b/packages/core/src/utils/llm-edit-fixer.test.ts new file mode 100644 index 00000000000..eaf6c9ba968 --- /dev/null +++ b/packages/core/src/utils/llm-edit-fixer.test.ts @@ -0,0 +1,204 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + FixLLMEditWithInstruction, + resetLlmEditFixerCaches_TEST_ONLY, + type SearchReplaceEdit, +} from './llm-edit-fixer.js'; +import { promptIdContext } from './promptIdContext.js'; +import type { LlmUtilityService } from '../core/llmUtilityService.js'; + +// Mock the LlmUtilityService +const mockGenerateJson = vi.fn(); +const mockLlmUtilityService = { + generateJson: mockGenerateJson, +} as unknown as LlmUtilityService; + +describe('FixLLMEditWithInstruction', () => { + const instruction = 'Replace the title'; + const old_string = '

Old Title

'; + const new_string = '

New Title

'; + const error = 'String not found'; + const current_content = '

Old Title

'; + const abortController = new AbortController(); + const abortSignal = abortController.signal; + + beforeEach(() => { + vi.clearAllMocks(); + resetLlmEditFixerCaches_TEST_ONLY(); // Ensure cache is cleared before each test + }); + + afterEach(() => { + vi.useRealTimers(); // Reset timers after each test + }); + + const mockApiResponse: SearchReplaceEdit = { + search: '

Old Title

', + replace: '

New Title

', + noChangesRequired: false, + explanation: 'The original search was correct.', + }; + + it('should use the promptId from the AsyncLocalStorage context when available', async () => { + const testPromptId = 'test-prompt-id-12345'; + mockGenerateJson.mockResolvedValue(mockApiResponse); + + // Run the function within the context + await promptIdContext.run({ promptId: testPromptId }, async () => { + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockLlmUtilityService, + abortSignal, + ); + }); + + // Verify that generateJson was called with the promptId from the context + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(mockGenerateJson).toHaveBeenCalledWith( + expect.objectContaining({ + promptId: testPromptId, + }), + ); + }); + + it('should generate and use a fallback promptId when context is not available', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => {}); + + // Run the function outside of any context + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockLlmUtilityService, + abortSignal, + ); + + // Verify the warning was logged + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining( + 'Could not find promptId in context. This is unexpected. Using a fallback ID: llm-fixer-fallback-', + ), + ); + + // Verify that generateJson was called with the generated fallback promptId + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(mockGenerateJson).toHaveBeenCalledWith( + expect.objectContaining({ + promptId: expect.stringContaining('llm-fixer-fallback-'), + }), + ); + + // Restore mocks + consoleWarnSpy.mockRestore(); + }); + + it('should construct the user prompt correctly', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const testPromptId = 'test-prompt-id-prompt-construction'; + + await promptIdContext.run({ promptId: testPromptId }, async () => { + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockLlmUtilityService, + abortSignal, + ); + }); + + const generateJsonCall = mockGenerateJson.mock.calls[0][0]; + const userPromptContent = generateJsonCall.contents[0].parts[0].text; + + expect(userPromptContent).toContain( + `\n${instruction}\n`, + ); + expect(userPromptContent).toContain(`\n${old_string}\n`); + expect(userPromptContent).toContain(`\n${new_string}\n`); + expect(userPromptContent).toContain(`\n${error}\n`); + expect(userPromptContent).toContain( + `\n${current_content}\n`, + ); + }); + + it('should return a cached result on subsequent identical calls', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const testPromptId = 'test-prompt-id-caching'; + + await promptIdContext.run({ promptId: testPromptId }, async () => { + // First call - should call the API + const result1 = await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockLlmUtilityService, + abortSignal, + ); + + // Second call with identical parameters - should hit the cache + const result2 = await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockLlmUtilityService, + abortSignal, + ); + + expect(result1).toEqual(mockApiResponse); + expect(result2).toEqual(mockApiResponse); + // Verify the underlying service was only called ONCE + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + }); + }); + + it('should not use cache for calls with different parameters', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const testPromptId = 'test-prompt-id-cache-miss'; + + await promptIdContext.run({ promptId: testPromptId }, async () => { + // First call + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockLlmUtilityService, + abortSignal, + ); + + // Second call with a different instruction + await FixLLMEditWithInstruction( + 'A different instruction', + old_string, + new_string, + error, + current_content, + mockLlmUtilityService, + abortSignal, + ); + + // Verify the underlying service was called TWICE + expect(mockGenerateJson).toHaveBeenCalledTimes(2); + }); + }); +}); diff --git a/packages/core/src/utils/llm-edit-fixer.ts b/packages/core/src/utils/llm-edit-fixer.ts index 95496d47794..861c6579f0f 100644 --- a/packages/core/src/utils/llm-edit-fixer.ts +++ b/packages/core/src/utils/llm-edit-fixer.ts @@ -5,9 +5,10 @@ */ import { type Content, Type } from '@google/genai'; -import { type GeminiClient } from '../core/client.js'; +import { type LlmUtilityService } from '../core/llmUtilityService.js'; import { LruCache } from './LruCache.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { promptIdContext } from './promptIdContext.js'; const MAX_CACHE_SIZE = 50; @@ -93,8 +94,9 @@ const editCorrectionWithInstructionCache = new LruCache< * @param new_string The original replacement string. * @param error The error that occurred during the initial edit. * @param current_content The current content of the file. - * @param geminiClient The Gemini client to use for the LLM call. + * @param llmUtilityService The LlmUtilityService to use for the LLM call. * @param abortSignal An abort signal to cancel the operation. + * @param promptId A unique ID for the prompt. * @returns A new search and replace pair. */ export async function FixLLMEditWithInstruction( @@ -103,9 +105,20 @@ export async function FixLLMEditWithInstruction( new_string: string, error: string, current_content: string, - geminiClient: GeminiClient, + llmUtilityService: LlmUtilityService, abortSignal: AbortSignal, ): Promise { + let promptId: string; + const context = promptIdContext.getStore(); + if (!context) { + promptId = `llm-fixer-fallback-${Date.now()}-${Math.random().toString(16).slice(2)}`; + console.warn( + `Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`, + ); + } else { + promptId = context.promptId; + } + const cacheKey = `${instruction}---${old_string}---${new_string}--${current_content}--${error}`; const cachedResult = editCorrectionWithInstructionCache.get(cacheKey); if (cachedResult) { @@ -120,21 +133,18 @@ export async function FixLLMEditWithInstruction( const contents: Content[] = [ { role: 'user', - parts: [ - { - text: `${EDIT_SYS_PROMPT} -${userPrompt}`, - }, - ], + parts: [{ text: userPrompt }], }, ]; - const result = (await geminiClient.generateJson( + const result = (await llmUtilityService.generateJson({ contents, - SearchReplaceEditSchema, + schema: SearchReplaceEditSchema, abortSignal, - DEFAULT_GEMINI_FLASH_MODEL, - )) as unknown as SearchReplaceEdit; + model: DEFAULT_GEMINI_FLASH_MODEL, + systemInstruction: EDIT_SYS_PROMPT, + promptId, + })) as unknown as SearchReplaceEdit; editCorrectionWithInstructionCache.set(cacheKey, result); return result; diff --git a/packages/core/src/utils/promptIdContext.ts b/packages/core/src/utils/promptIdContext.ts new file mode 100644 index 00000000000..22f501904ff --- /dev/null +++ b/packages/core/src/utils/promptIdContext.ts @@ -0,0 +1,13 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { AsyncLocalStorage } from 'node:async_hooks'; + +interface PromptContext { + promptId: string; +} + +export const promptIdContext = new AsyncLocalStorage(); From f3c9ce75a1992f11e9e742dbfc29c870c68b8ce8 Mon Sep 17 00:00:00 2001 From: Abhi Date: Mon, 8 Sep 2025 20:08:11 -0400 Subject: [PATCH 3/3] Address Jacob's comments --- packages/cli/src/nonInteractiveCli.ts | 2 +- packages/cli/src/ui/hooks/useGeminiStream.ts | 6 ++-- packages/core/src/config/config.test.ts | 18 +++++----- packages/core/src/config/config.ts | 20 +++++------ ...yService.test.ts => baseLlmClient.test.ts} | 33 +++++++++---------- ...{llmUtilityService.ts => baseLlmClient.ts} | 4 +-- packages/core/src/tools/smart-edit.test.ts | 10 +++--- packages/core/src/tools/smart-edit.ts | 2 +- .../core/src/utils/llm-edit-fixer.test.ts | 33 +++++++++---------- packages/core/src/utils/llm-edit-fixer.ts | 15 ++++----- packages/core/src/utils/promptIdContext.ts | 6 +--- 11 files changed, 69 insertions(+), 80 deletions(-) rename packages/core/src/core/{llmUtilityService.test.ts => baseLlmClient.test.ts} (91%) rename packages/core/src/core/{llmUtilityService.ts => baseLlmClient.ts} (97%) diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 4adcd93f305..ff33bd86ec6 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -25,7 +25,7 @@ export async function runNonInteractive( input: string, prompt_id: string, ): Promise { - return promptIdContext.run({ promptId: prompt_id }, async () => { + return promptIdContext.run(prompt_id, async () => { const consolePatcher = new ConsolePatcher({ stderr: true, debugMode: config.getDebugMode(), diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 618329b4a40..ac92068edb4 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -706,12 +706,12 @@ export const useGeminiStream = ( if (!prompt_id) { prompt_id = config.getSessionId() + '########' + getPromptCount(); } - return promptIdContext.run({ promptId: prompt_id! }, async () => { + return promptIdContext.run(prompt_id, async () => { const { queryToSend, shouldProceed } = await prepareQueryForGemini( query, userMessageTimestamp, abortSignal, - prompt_id!, + prompt_id, ); if (!shouldProceed || queryToSend === null) { @@ -730,7 +730,7 @@ export const useGeminiStream = ( const stream = geminiClient.sendMessageStream( queryToSend, abortSignal, - prompt_id!, + prompt_id, ); const processingStatus = await processGeminiStreamEvents( stream, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index e5e10d98a39..81096b69d50 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -122,9 +122,9 @@ vi.mock('../ide/ide-client.js', () => ({ }, })); -import { LlmUtilityService } from '../core/llmUtilityService.js'; +import { BaseLlmClient } from '../core/baseLlmClient.js'; -vi.mock('../core/llmUtilityService.js'); +vi.mock('../core/baseLlmClient.js'); describe('Server Config (config.ts)', () => { const MODEL = 'gemini-pro'; @@ -779,7 +779,7 @@ describe('setApprovalMode with folder trust', () => { }); }); -describe('LlmUtilityService Lifecycle', () => { +describe('BaseLlmClient Lifecycle', () => { const MODEL = 'gemini-pro'; const SANDBOX: SandboxConfig = { command: 'docker', @@ -808,14 +808,14 @@ describe('LlmUtilityService Lifecycle', () => { usageStatisticsEnabled: false, }; - it('should throw an error if getLlmUtilityService is called before refreshAuth', () => { + it('should throw an error if getBaseLlmClient is called before refreshAuth', () => { const config = new Config(baseParams); - expect(() => config.getLlmUtilityService()).toThrow( - 'LlmUtilityService not initialized. Ensure authentication has occurred and GeminiClient is ready.', + expect(() => config.getBaseLlmClient()).toThrow( + 'BaseLlmClient not initialized. Ensure authentication has occurred and ContentGenerator is ready.', ); }); - it('should successfully initialize LlmUtilityService after refreshAuth is called', async () => { + it('should successfully initialize BaseLlmClient after refreshAuth is called', async () => { const config = new Config(baseParams); const authType = AuthType.USE_GEMINI; const mockContentConfig = { model: 'gemini-flash', apiKey: 'test-key' }; @@ -825,9 +825,9 @@ describe('LlmUtilityService Lifecycle', () => { await config.refreshAuth(authType); // Should not throw - const llmService = config.getLlmUtilityService(); + const llmService = config.getBaseLlmClient(); expect(llmService).toBeDefined(); - expect(LlmUtilityService).toHaveBeenCalledWith( + expect(BaseLlmClient).toHaveBeenCalledWith( config.getContentGenerator(), config, ); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index a38dd9d18bf..c4dfc3a56ce 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -31,7 +31,7 @@ import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js'; import { WebSearchTool } from '../tools/web-search.js'; import { GeminiClient } from '../core/client.js'; -import { LlmUtilityService } from '../core/llmUtilityService.js'; +import { BaseLlmClient } from '../core/baseLlmClient.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { GitService } from '../services/gitService.js'; import type { TelemetryTarget } from '../telemetry/index.js'; @@ -258,7 +258,7 @@ export class Config { private readonly telemetrySettings: TelemetrySettings; private readonly usageStatisticsEnabled: boolean; private geminiClient!: GeminiClient; - private llmUtilityService!: LlmUtilityService; + private baseLlmClient!: BaseLlmClient; private readonly fileFiltering: { respectGitIgnore: boolean; respectGeminiIgnore: boolean; @@ -457,8 +457,8 @@ export class Config { // Only assign to instance properties after successful initialization this.contentGeneratorConfig = newContentGeneratorConfig; - // Initialize LlmUtilityService now that the ContentGenerator is available - this.llmUtilityService = new LlmUtilityService(this.contentGenerator, this); + // Initialize BaseLlmClient now that the ContentGenerator is available + this.baseLlmClient = new BaseLlmClient(this.contentGenerator, this); // Reset the session flag since we're explicitly changing auth and using default model this.inFallbackMode = false; @@ -469,23 +469,23 @@ export class Config { } /** - * Provides access to the LlmUtilityService for stateless LLM operations. + * Provides access to the BaseLlmClient for stateless LLM operations. */ - getLlmUtilityService(): LlmUtilityService { - if (!this.llmUtilityService) { + getBaseLlmClient(): BaseLlmClient { + if (!this.baseLlmClient) { // Handle cases where initialization might be deferred or authentication failed if (this.contentGenerator) { - this.llmUtilityService = new LlmUtilityService( + this.baseLlmClient = new BaseLlmClient( this.getContentGenerator(), this, ); } else { throw new Error( - 'LlmUtilityService not initialized. Ensure authentication has occurred and GeminiClient is ready.', + 'BaseLlmClient not initialized. Ensure authentication has occurred and ContentGenerator is ready.', ); } } - return this.llmUtilityService; + return this.baseLlmClient; } getSessionId(): string { diff --git a/packages/core/src/core/llmUtilityService.test.ts b/packages/core/src/core/baseLlmClient.test.ts similarity index 91% rename from packages/core/src/core/llmUtilityService.test.ts rename to packages/core/src/core/baseLlmClient.test.ts index 03f1109c054..1b1787f5fd4 100644 --- a/packages/core/src/core/llmUtilityService.test.ts +++ b/packages/core/src/core/baseLlmClient.test.ts @@ -15,10 +15,7 @@ import { } from 'vitest'; import type { GenerateContentResponse } from '@google/genai'; -import { - LlmUtilityService, - type GenerateJsonOptions, -} from './llmUtilityService.js'; +import { BaseLlmClient, type GenerateJsonOptions } from './baseLlmClient.js'; import type { ContentGenerator } from './contentGenerator.js'; import type { Config } from '../config/config.js'; import { AuthType } from './contentGenerator.js'; @@ -61,8 +58,8 @@ const createMockResponse = (text: string): GenerateContentResponse => candidates: [{ content: { role: 'model', parts: [{ text }] }, index: 0 }], }) as GenerateContentResponse; -describe('LlmUtilityService', () => { - let service: LlmUtilityService; +describe('BaseLlmClient', () => { + let client: BaseLlmClient; let abortController: AbortController; let defaultOptions: GenerateJsonOptions; @@ -72,7 +69,7 @@ describe('LlmUtilityService', () => { vi.mocked(getErrorMessage).mockImplementation((e) => e instanceof Error ? e.message : String(e), ); - service = new LlmUtilityService(mockContentGenerator, mockConfig); + client = new BaseLlmClient(mockContentGenerator, mockConfig); abortController = new AbortController(); defaultOptions = { contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }], @@ -92,7 +89,7 @@ describe('LlmUtilityService', () => { const mockResponse = createMockResponse('{"color": "blue"}'); mockGenerateContent.mockResolvedValue(mockResponse); - const result = await service.generateJson(defaultOptions); + const result = await client.generateJson(defaultOptions); expect(result).toEqual({ color: 'blue' }); @@ -127,7 +124,7 @@ describe('LlmUtilityService', () => { config: { temperature: 0.8, topK: 10 }, }; - await service.generateJson(options); + await client.generateJson(options); expect(mockGenerateContent).toHaveBeenCalledWith( expect.objectContaining({ @@ -151,7 +148,7 @@ describe('LlmUtilityService', () => { systemInstruction, }; - await service.generateJson(options); + await client.generateJson(options); expect(mockGenerateContent).toHaveBeenCalledWith( expect.objectContaining({ @@ -173,7 +170,7 @@ describe('LlmUtilityService', () => { promptId: customPromptId, }; - await service.generateJson(options); + await client.generateJson(options); expect(mockGenerateContent).toHaveBeenCalledWith( expect.any(Object), @@ -189,7 +186,7 @@ describe('LlmUtilityService', () => { createMockResponse(malformedResponse), ); - const result = await service.generateJson(defaultOptions); + const result = await client.generateJson(defaultOptions); expect(result).toEqual({ color: 'purple' }); expect(logMalformedJsonResponse).toHaveBeenCalledTimes(1); @@ -209,7 +206,7 @@ describe('LlmUtilityService', () => { createMockResponse(responseWithWhitespace), ); - const result = await service.generateJson(defaultOptions); + const result = await client.generateJson(defaultOptions); expect(result).toEqual({ color: 'orange' }); expect(logMalformedJsonResponse).not.toHaveBeenCalled(); @@ -220,8 +217,8 @@ describe('LlmUtilityService', () => { it('should throw and report error for empty response', async () => { mockGenerateContent.mockResolvedValue(createMockResponse('')); - // The final error message includes the prefix added by the service's outer catch block. - await expect(service.generateJson(defaultOptions)).rejects.toThrow( + // The final error message includes the prefix added by the client's outer catch block. + await expect(client.generateJson(defaultOptions)).rejects.toThrow( 'Failed to generate JSON content: API returned an empty response for generateJson.', ); @@ -239,7 +236,7 @@ describe('LlmUtilityService', () => { const invalidJson = '{"color": "blue"'; // missing closing brace mockGenerateContent.mockResolvedValue(createMockResponse(invalidJson)); - await expect(service.generateJson(defaultOptions)).rejects.toThrow( + await expect(client.generateJson(defaultOptions)).rejects.toThrow( /^Failed to generate JSON content: Failed to parse API response as JSON:/, ); @@ -257,7 +254,7 @@ describe('LlmUtilityService', () => { // Simulate the generator failing mockGenerateContent.mockRejectedValue(apiError); - await expect(service.generateJson(defaultOptions)).rejects.toThrow( + await expect(client.generateJson(defaultOptions)).rejects.toThrow( 'Failed to generate JSON content: Service Unavailable (503)', ); @@ -285,7 +282,7 @@ describe('LlmUtilityService', () => { abortSignal: abortController.signal, }; - await expect(service.generateJson(options)).rejects.toThrow(abortError); + await expect(client.generateJson(options)).rejects.toThrow(abortError); // Crucially, it should not report a cancellation as an application error expect(reportError).not.toHaveBeenCalled(); diff --git a/packages/core/src/core/llmUtilityService.ts b/packages/core/src/core/baseLlmClient.ts similarity index 97% rename from packages/core/src/core/llmUtilityService.ts rename to packages/core/src/core/baseLlmClient.ts index 96fdeb5b55c..25a92dabdd7 100644 --- a/packages/core/src/core/llmUtilityService.ts +++ b/packages/core/src/core/baseLlmClient.ts @@ -49,9 +49,9 @@ export interface GenerateJsonOptions { } /** - * A service dedicated to stateless, utility-focused LLM calls. + * A client dedicated to stateless, utility-focused LLM calls. */ -export class LlmUtilityService { +export class BaseLlmClient { // Default configuration for utility tasks private readonly defaultUtilityConfig: GenerateContentConfig = { temperature: 0, diff --git a/packages/core/src/tools/smart-edit.test.ts b/packages/core/src/tools/smart-edit.test.ts index 6d1fb3d93f5..c72fcb48df8 100644 --- a/packages/core/src/tools/smart-edit.test.ts +++ b/packages/core/src/tools/smart-edit.test.ts @@ -60,7 +60,7 @@ import { ApprovalMode, type Config } from '../config/config.js'; import { type Content, type Part, type SchemaUnion } from '@google/genai'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; -import type { LlmUtilityService } from '../core/llmUtilityService.js'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; describe('SmartEditTool', () => { let tool: SmartEditTool; @@ -68,7 +68,7 @@ describe('SmartEditTool', () => { let rootDir: string; let mockConfig: Config; let geminiClient: any; - let llmUtilityService: LlmUtilityService; + let baseLlmClient: BaseLlmClient; beforeEach(() => { vi.restoreAllMocks(); @@ -80,13 +80,13 @@ describe('SmartEditTool', () => { generateJson: mockGenerateJson, }; - llmUtilityService = { + baseLlmClient = { generateJson: mockGenerateJson, - } as unknown as LlmUtilityService; + } as unknown as BaseLlmClient; mockConfig = { getGeminiClient: vi.fn().mockReturnValue(geminiClient), - getLlmUtilityService: vi.fn().mockReturnValue(llmUtilityService), + getBaseLlmClient: vi.fn().mockReturnValue(baseLlmClient), getTargetDir: () => rootDir, getApprovalMode: vi.fn(), setApprovalMode: vi.fn(), diff --git a/packages/core/src/tools/smart-edit.ts b/packages/core/src/tools/smart-edit.ts index 920222ba452..6291296f132 100644 --- a/packages/core/src/tools/smart-edit.ts +++ b/packages/core/src/tools/smart-edit.ts @@ -310,7 +310,7 @@ class EditToolInvocation implements ToolInvocation { params.new_string, initialError.raw, currentContent, - this.config.getLlmUtilityService(), + this.config.getBaseLlmClient(), abortSignal, ); diff --git a/packages/core/src/utils/llm-edit-fixer.test.ts b/packages/core/src/utils/llm-edit-fixer.test.ts index eaf6c9ba968..4c236ad3425 100644 --- a/packages/core/src/utils/llm-edit-fixer.test.ts +++ b/packages/core/src/utils/llm-edit-fixer.test.ts @@ -11,13 +11,13 @@ import { type SearchReplaceEdit, } from './llm-edit-fixer.js'; import { promptIdContext } from './promptIdContext.js'; -import type { LlmUtilityService } from '../core/llmUtilityService.js'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; -// Mock the LlmUtilityService +// Mock the BaseLlmClient const mockGenerateJson = vi.fn(); -const mockLlmUtilityService = { +const mockBaseLlmClient = { generateJson: mockGenerateJson, -} as unknown as LlmUtilityService; +} as unknown as BaseLlmClient; describe('FixLLMEditWithInstruction', () => { const instruction = 'Replace the title'; @@ -48,15 +48,14 @@ describe('FixLLMEditWithInstruction', () => { const testPromptId = 'test-prompt-id-12345'; mockGenerateJson.mockResolvedValue(mockApiResponse); - // Run the function within the context - await promptIdContext.run({ promptId: testPromptId }, async () => { + await promptIdContext.run(testPromptId, async () => { await FixLLMEditWithInstruction( instruction, old_string, new_string, error, current_content, - mockLlmUtilityService, + mockBaseLlmClient, abortSignal, ); }); @@ -83,7 +82,7 @@ describe('FixLLMEditWithInstruction', () => { new_string, error, current_content, - mockLlmUtilityService, + mockBaseLlmClient, abortSignal, ); @@ -108,16 +107,16 @@ describe('FixLLMEditWithInstruction', () => { it('should construct the user prompt correctly', async () => { mockGenerateJson.mockResolvedValue(mockApiResponse); - const testPromptId = 'test-prompt-id-prompt-construction'; + const promptId = 'test-prompt-id-prompt-construction'; - await promptIdContext.run({ promptId: testPromptId }, async () => { + await promptIdContext.run(promptId, async () => { await FixLLMEditWithInstruction( instruction, old_string, new_string, error, current_content, - mockLlmUtilityService, + mockBaseLlmClient, abortSignal, ); }); @@ -140,7 +139,7 @@ describe('FixLLMEditWithInstruction', () => { mockGenerateJson.mockResolvedValue(mockApiResponse); const testPromptId = 'test-prompt-id-caching'; - await promptIdContext.run({ promptId: testPromptId }, async () => { + await promptIdContext.run(testPromptId, async () => { // First call - should call the API const result1 = await FixLLMEditWithInstruction( instruction, @@ -148,7 +147,7 @@ describe('FixLLMEditWithInstruction', () => { new_string, error, current_content, - mockLlmUtilityService, + mockBaseLlmClient, abortSignal, ); @@ -159,7 +158,7 @@ describe('FixLLMEditWithInstruction', () => { new_string, error, current_content, - mockLlmUtilityService, + mockBaseLlmClient, abortSignal, ); @@ -174,7 +173,7 @@ describe('FixLLMEditWithInstruction', () => { mockGenerateJson.mockResolvedValue(mockApiResponse); const testPromptId = 'test-prompt-id-cache-miss'; - await promptIdContext.run({ promptId: testPromptId }, async () => { + await promptIdContext.run(testPromptId, async () => { // First call await FixLLMEditWithInstruction( instruction, @@ -182,7 +181,7 @@ describe('FixLLMEditWithInstruction', () => { new_string, error, current_content, - mockLlmUtilityService, + mockBaseLlmClient, abortSignal, ); @@ -193,7 +192,7 @@ describe('FixLLMEditWithInstruction', () => { new_string, error, current_content, - mockLlmUtilityService, + mockBaseLlmClient, abortSignal, ); diff --git a/packages/core/src/utils/llm-edit-fixer.ts b/packages/core/src/utils/llm-edit-fixer.ts index 861c6579f0f..a4b4b131c0c 100644 --- a/packages/core/src/utils/llm-edit-fixer.ts +++ b/packages/core/src/utils/llm-edit-fixer.ts @@ -5,7 +5,7 @@ */ import { type Content, Type } from '@google/genai'; -import { type LlmUtilityService } from '../core/llmUtilityService.js'; +import { type BaseLlmClient } from '../core/baseLlmClient.js'; import { LruCache } from './LruCache.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { promptIdContext } from './promptIdContext.js'; @@ -94,7 +94,7 @@ const editCorrectionWithInstructionCache = new LruCache< * @param new_string The original replacement string. * @param error The error that occurred during the initial edit. * @param current_content The current content of the file. - * @param llmUtilityService The LlmUtilityService to use for the LLM call. + * @param baseLlmClient The BaseLlmClient to use for the LLM call. * @param abortSignal An abort signal to cancel the operation. * @param promptId A unique ID for the prompt. * @returns A new search and replace pair. @@ -105,18 +105,15 @@ export async function FixLLMEditWithInstruction( new_string: string, error: string, current_content: string, - llmUtilityService: LlmUtilityService, + baseLlmClient: BaseLlmClient, abortSignal: AbortSignal, ): Promise { - let promptId: string; - const context = promptIdContext.getStore(); - if (!context) { + let promptId = promptIdContext.getStore(); + if (!promptId) { promptId = `llm-fixer-fallback-${Date.now()}-${Math.random().toString(16).slice(2)}`; console.warn( `Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`, ); - } else { - promptId = context.promptId; } const cacheKey = `${instruction}---${old_string}---${new_string}--${current_content}--${error}`; @@ -137,7 +134,7 @@ export async function FixLLMEditWithInstruction( }, ]; - const result = (await llmUtilityService.generateJson({ + const result = (await baseLlmClient.generateJson({ contents, schema: SearchReplaceEditSchema, abortSignal, diff --git a/packages/core/src/utils/promptIdContext.ts b/packages/core/src/utils/promptIdContext.ts index 22f501904ff..6344bd0b834 100644 --- a/packages/core/src/utils/promptIdContext.ts +++ b/packages/core/src/utils/promptIdContext.ts @@ -6,8 +6,4 @@ import { AsyncLocalStorage } from 'node:async_hooks'; -interface PromptContext { - promptId: string; -} - -export const promptIdContext = new AsyncLocalStorage(); +export const promptIdContext = new AsyncLocalStorage();