Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion packages/a2a-server/src/agent/executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ export class CoderAgentExecutor implements AgentExecutor {
contextId,
config,
eventBus,
agentSettings.autoExecute,
);
runtimeTask.taskState = persistedState._taskState;
await runtimeTask.geminiClient.initialize();
Expand All @@ -145,7 +146,13 @@ export class CoderAgentExecutor implements AgentExecutor {
): Promise<TaskWrapper> {
const agentSettings = agentSettingsInput || ({} as AgentSettings);
const config = await this.getConfig(agentSettings, taskId);
const runtimeTask = await Task.create(taskId, contextId, config, eventBus);
const runtimeTask = await Task.create(
taskId,
contextId,
config,
eventBus,
agentSettings.autoExecute,
);
await runtimeTask.geminiClient.initialize();

const wrapper = new TaskWrapper(runtimeTask, agentSettings);
Expand Down
69 changes: 67 additions & 2 deletions packages/a2a-server/src/agent/task.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import {
type ToolCallRequestInfo,
type GitService,
type CompletedToolCall,
ApprovalMode,
ToolConfirmationOutcome,
} from '@google/gemini-cli-core';
import { createMockConfig } from '../utils/testing_utils.js';
import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server';
Expand Down Expand Up @@ -353,10 +355,12 @@ describe('Task', () => {
let task: Task;
type SpyInstance = ReturnType<typeof vi.spyOn>;
let setTaskStateAndPublishUpdateSpy: SpyInstance;
let mockConfig: Config;
let mockEventBus: ExecutionEventBus;

beforeEach(() => {
const mockConfig = createMockConfig();
const mockEventBus: ExecutionEventBus = {
mockConfig = createMockConfig() as Config;
mockEventBus = {
publish: vi.fn(),
on: vi.fn(),
off: vi.fn(),
Expand Down Expand Up @@ -465,6 +469,67 @@ describe('Task', () => {
);
expect(finalCall).toBeUndefined();
});

describe('auto-approval', () => {
  /**
   * Builds a one-element tool-call list whose single entry is waiting for
   * approval and reports the decision through `onConfirm`.
   */
  const awaitingApprovalCalls = (
    onConfirm: ReturnType<typeof vi.fn>,
  ): ToolCall[] =>
    [
      {
        request: { callId: '1' },
        status: 'awaiting_approval',
        confirmationDetails: { onConfirm },
      },
    ] as unknown as ToolCall[];

  /** Makes the mocked config report the given approval mode. */
  const setApprovalMode = (mode: ApprovalMode): void => {
    (mockConfig.getApprovalMode as Mock).mockReturnValue(mode);
  };

  /** Feeds the tool calls to the task's scheduler-update handler. */
  const runSchedulerUpdate = (toolCalls: ToolCall[]): void => {
    // @ts-expect-error - Calling private method
    task._schedulerToolCallsUpdate(toolCalls);
  };

  it('should auto-approve tool calls when autoExecute is true', () => {
    task.autoExecute = true;
    // Pin the approval mode so this test exercises the autoExecute flag alone
    // and cannot pass merely because the mock config happens to report YOLO.
    setApprovalMode(ApprovalMode.DEFAULT);
    const onConfirmSpy = vi.fn();

    runSchedulerUpdate(awaitingApprovalCalls(onConfirmSpy));

    expect(onConfirmSpy).toHaveBeenCalledWith(
      ToolConfirmationOutcome.ProceedOnce,
    );
  });

  it('should auto-approve tool calls when approval mode is YOLO', () => {
    setApprovalMode(ApprovalMode.YOLO);
    task.autoExecute = false;
    const onConfirmSpy = vi.fn();

    runSchedulerUpdate(awaitingApprovalCalls(onConfirmSpy));

    expect(onConfirmSpy).toHaveBeenCalledWith(
      ToolConfirmationOutcome.ProceedOnce,
    );
  });

  it('should NOT auto-approve when autoExecute is false and mode is not YOLO', () => {
    task.autoExecute = false;
    setApprovalMode(ApprovalMode.DEFAULT);
    const onConfirmSpy = vi.fn();

    runSchedulerUpdate(awaitingApprovalCalls(onConfirmSpy));

    expect(onConfirmSpy).not.toHaveBeenCalled();
  });
});
});

describe('currentPromptId and promptCount', () => {
Expand Down
17 changes: 14 additions & 3 deletions packages/a2a-server/src/agent/task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ export class Task {
modelInfo?: string;
currentPromptId: string | undefined;
promptCount = 0;
autoExecute: boolean;

// For tool waiting logic
private pendingToolCalls: Map<string, string> = new Map(); //toolCallId --> status
Expand All @@ -87,6 +88,7 @@ export class Task {
contextId: string,
config: Config,
eventBus?: ExecutionEventBus,
autoExecute = false,
) {
this.id = id;
this.contextId = contextId;
Expand All @@ -98,6 +100,7 @@ export class Task {
this.eventBus = eventBus;
this.completedToolCalls = [];
this._resetToolCompletionPromise();
this.autoExecute = autoExecute;
this.config.setFallbackModelHandler(
// For a2a-server, we want to automatically switch to the fallback model
// for future requests without retrying the current one. The 'stop'
Expand All @@ -111,8 +114,9 @@ export class Task {
contextId: string,
config: Config,
eventBus?: ExecutionEventBus,
autoExecute?: boolean,
): Promise<Task> {
return new Task(id, contextId, config, eventBus);
return new Task(id, contextId, config, eventBus, autoExecute);
}

// Note: `getAllMCPServerStatuses` retrieves the status of all MCP servers for the entire
Expand Down Expand Up @@ -396,8 +400,15 @@ export class Task {
}
});

if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
logger.info('[Task] YOLO mode enabled. Auto-approving all tool calls.');
if (
this.autoExecute ||
this.config.getApprovalMode() === ApprovalMode.YOLO
) {
logger.info(
'[Task] ' +
(this.autoExecute ? '' : 'YOLO mode enabled. ') +
'Auto-approving all tool calls.',
);
toolCalls.forEach((tc: ToolCall) => {
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
Expand Down
2 changes: 2 additions & 0 deletions packages/a2a-server/src/commands/command-registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

import { ExtensionsCommand } from './extensions.js';
import { InitCommand } from './init.js';
import { RestoreCommand } from './restore.js';
import type { Command } from './types.js';

Expand All @@ -14,6 +15,7 @@ class CommandRegistry {
/**
 * Registers the built-in commands: extensions, restore, then init.
 */
constructor() {
  // Factories (rather than instances) keep the original interleaving:
  // each command is constructed immediately before it is registered.
  const builtInFactories: Array<() => Command> = [
    () => new ExtensionsCommand(),
    () => new RestoreCommand(),
    () => new InitCommand(),
  ];
  for (const createCommand of builtInFactories) {
    this.register(createCommand());
  }
}

register(command: Command) {
Expand Down
182 changes: 182 additions & 0 deletions packages/a2a-server/src/commands/init.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

import { describe, it, expect, vi, beforeEach } from 'vitest';
import { InitCommand } from './init.js';
import { performInit } from '@google/gemini-cli-core';
import * as fs from 'node:fs';
import * as path from 'node:path';
import { CoderAgentExecutor } from '../agent/executor.js';
import { CoderAgentEvent } from '../types.js';
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
import { createMockConfig } from '../utils/testing_utils.js';
import type { CommandContext } from './types.js';
import type { CommandActionReturn, Config } from '@google/gemini-cli-core';
import { logger } from '../utils/logger.js';

// Keep the real core package, but replace performInit with a mock function so
// each test can choose what it returns.
vi.mock('@google/gemini-cli-core', async (importOriginal) => ({
  ...(await importOriginal<typeof import('@google/gemini-cli-core')>()),
  performInit: vi.fn(),
}));

// Only the two fs functions below are provided, both as mock functions.
vi.mock('node:fs', () => {
  return {
    existsSync: vi.fn(),
    writeFileSync: vi.fn(),
  };
});

// The executor module exposes a mock CoderAgentExecutor whose implementation
// returns an object with a mocked `execute`.
vi.mock('../agent/executor.js', () => {
  const executorMock = vi.fn().mockImplementation(() => ({
    execute: vi.fn(),
  }));
  return { CoderAgentExecutor: executorMock };
});

// Logger calls are captured by mock functions so tests can assert on them.
vi.mock('../utils/logger.js', () => {
  const loggerStub = {
    info: vi.fn(),
    error: vi.fn(),
  };
  return { logger: loggerStub };
});

/**
 * Tests for InitCommand: how it reports the message results returned by the
 * mocked performInit, and how it hands a `submit_prompt` result to the agent
 * executor.
 */
describe('InitCommand', () => {
  const workspaceRoot = path.resolve('/tmp');

  let bus: ExecutionEventBus;
  let initCommand: InitCommand;
  let commandContext: CommandContext;
  let publishSpy: ReturnType<typeof vi.spyOn>;
  let executeMock: ReturnType<typeof vi.fn>;

  /** Matcher for a status-update event carrying one text part. */
  const statusUpdateWith = (state: string, text: string) =>
    expect.objectContaining({
      kind: 'status-update',
      status: expect.objectContaining({
        state,
        message: expect.objectContaining({
          parts: [{ kind: 'text', text }],
        }),
      }),
    });

  /** Makes the mocked performInit return the given result. */
  const stubPerformInit = (result: CommandActionReturn): void => {
    vi.mocked(performInit).mockReturnValue(result);
  };

  beforeEach(() => {
    process.env['CODER_AGENT_WORKSPACE_PATH'] = workspaceRoot;

    bus = {
      publish: vi.fn(),
    } as unknown as ExecutionEventBus;
    initCommand = new InitCommand();

    const config = createMockConfig({
      getModel: () => 'gemini-pro',
    });
    const executor = new CoderAgentExecutor();
    commandContext = {
      config: config as unknown as Config,
      agentExecutor: executor,
      eventBus: bus,
    } as CommandContext;

    publishSpy = vi.spyOn(bus, 'publish');
    executeMock = vi.fn();
    vi.spyOn(executor, 'execute').mockImplementation(executeMock);

    // Runs last so every test starts with empty call histories.
    vi.clearAllMocks();
  });

  it('has requiresWorkspace set to true', () => {
    expect(initCommand.requiresWorkspace).toBe(true);
  });

  describe('execute', () => {
    it('handles info from performInit', async () => {
      stubPerformInit({
        type: 'message',
        messageType: 'info',
        content: 'GEMINI.md already exists.',
      } as CommandActionReturn);

      await initCommand.execute(commandContext, []);

      // The same event is expected both in the log and on the event bus.
      expect(logger.info).toHaveBeenCalledWith(
        '[EventBus event]: ',
        statusUpdateWith('completed', 'GEMINI.md already exists.'),
      );
      expect(publishSpy).toHaveBeenCalledWith(
        statusUpdateWith('completed', 'GEMINI.md already exists.'),
      );
    });

    it('handles error from performInit', async () => {
      stubPerformInit({
        type: 'message',
        messageType: 'error',
        content: 'An error occurred.',
      } as CommandActionReturn);

      await initCommand.execute(commandContext, []);

      expect(publishSpy).toHaveBeenCalledWith(
        statusUpdateWith('failed', 'An error occurred.'),
      );
    });

    describe('when handling submit_prompt', () => {
      beforeEach(() => {
        stubPerformInit({
          type: 'submit_prompt',
          content: 'Create a new GEMINI.md file.',
        } as CommandActionReturn);
      });

      it('writes the file and executes the agent', async () => {
        await initCommand.execute(commandContext, []);

        const expectedFile = path.join(workspaceRoot, 'GEMINI.md');
        expect(fs.writeFileSync).toHaveBeenCalledWith(expectedFile, '', 'utf8');
        expect(executeMock).toHaveBeenCalled();
      });

      it('passes autoExecute to the agent executor', async () => {
        await initCommand.execute(commandContext, []);

        const expectedRequest = expect.objectContaining({
          userMessage: expect.objectContaining({
            parts: expect.arrayContaining([
              expect.objectContaining({
                text: 'Create a new GEMINI.md file.',
              }),
            ]),
            metadata: {
              coderAgent: {
                kind: CoderAgentEvent.StateAgentSettingsEvent,
                workspacePath: workspaceRoot,
                autoExecute: true,
              },
            },
          }),
        });
        expect(executeMock).toHaveBeenCalledWith(expectedRequest, bus);
      });
    });
  });
});
Loading