Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions packages/core/src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ import { WorkspaceContext } from '../utils/workspaceContext.js';
import { Storage } from './storage.js';
import { FileExclusions } from '../utils/ignorePatterns.js';
import type { EventEmitter } from 'node:events';
import { MessageBus } from '../confirmation-bus/message-bus.js';
import { PolicyEngine } from '../policy/policy-engine.js';
import type { PolicyEngineConfig } from '../policy/types.js';
import type { UserTierId } from '../code_assist/types.js';
import { ProxyAgent, setGlobalDispatcher } from 'undici';

Expand Down Expand Up @@ -228,6 +231,7 @@ export interface ConfigParameters {
enableToolOutputTruncation?: boolean;
eventEmitter?: EventEmitter;
useSmartEdit?: boolean;
policyEngineConfig?: PolicyEngineConfig;
}

export class Config {
Expand Down Expand Up @@ -310,6 +314,8 @@ export class Config {
private readonly fileExclusions: FileExclusions;
private readonly eventEmitter?: EventEmitter;
private readonly useSmartEdit: boolean;
private readonly messageBus: MessageBus;
private readonly policyEngine: PolicyEngine;

constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
Expand Down Expand Up @@ -393,6 +399,8 @@ export class Config {
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
this.fileExclusions = new FileExclusions(this);
this.eventEmitter = params.eventEmitter;
this.policyEngine = new PolicyEngine(params.policyEngineConfig);
this.messageBus = new MessageBus(this.policyEngine);

if (params.contextFileName) {
setGeminiMdFilename(params.contextFileName);
Expand Down Expand Up @@ -892,6 +900,14 @@ export class Config {
return this.fileExclusions;
}

getMessageBus(): MessageBus {
return this.messageBus;
}

getPolicyEngine(): PolicyEngine {
return this.policyEngine;
}

async createToolRegistry(): Promise<ToolRegistry> {
const registry = new ToolRegistry(this, this.eventEmitter);

Expand Down
8 changes: 8 additions & 0 deletions packages/core/src/confirmation-bus/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

export * from './message-bus.js';
export * from './types.js';
235 changes: 235 additions & 0 deletions packages/core/src/confirmation-bus/message-bus.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

import { describe, it, expect, beforeEach, vi } from 'vitest';
import { MessageBus } from './message-bus.js';
import { PolicyEngine } from '../policy/policy-engine.js';
import { PolicyDecision } from '../policy/types.js';
import {
MessageBusType,
type ToolConfirmationRequest,
type ToolConfirmationResponse,
type ToolPolicyRejection,
type ToolExecutionSuccess,
} from './types.js';

describe('MessageBus', () => {
let messageBus: MessageBus;
let policyEngine: PolicyEngine;

beforeEach(() => {
policyEngine = new PolicyEngine();
messageBus = new MessageBus(policyEngine);
});

describe('publish', () => {
it('should emit error for invalid message', () => {
const errorHandler = vi.fn();
messageBus.on('error', errorHandler);

// @ts-expect-error - Testing invalid message
messageBus.publish({ invalid: 'message' });

expect(errorHandler).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.stringContaining('Invalid message structure'),
}),
);
});

it('should validate tool confirmation requests have correlationId', () => {
const errorHandler = vi.fn();
messageBus.on('error', errorHandler);

// @ts-expect-error - Testing missing correlationId
messageBus.publish({
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: { name: 'test' },
});

expect(errorHandler).toHaveBeenCalled();
});

it('should emit confirmation response when policy allows', () => {
vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ALLOW);

const responseHandler = vi.fn();
messageBus.subscribe(
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
responseHandler,
);

const request: ToolConfirmationRequest = {
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: { name: 'test-tool', args: {} },
correlationId: '123',
};

messageBus.publish(request);

const expectedResponse: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: '123',
confirmed: true,
};
expect(responseHandler).toHaveBeenCalledWith(expectedResponse);
});

it('should emit rejection and response when policy denies', () => {
vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.DENY);

const responseHandler = vi.fn();
const rejectionHandler = vi.fn();
messageBus.subscribe(
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
responseHandler,
);
messageBus.subscribe(
MessageBusType.TOOL_POLICY_REJECTION,
rejectionHandler,
);

const request: ToolConfirmationRequest = {
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: { name: 'test-tool', args: {} },
correlationId: '123',
};

messageBus.publish(request);

const expectedRejection: ToolPolicyRejection = {
type: MessageBusType.TOOL_POLICY_REJECTION,
toolCall: { name: 'test-tool', args: {} },
};
expect(rejectionHandler).toHaveBeenCalledWith(expectedRejection);

const expectedResponse: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: '123',
confirmed: false,
};
expect(responseHandler).toHaveBeenCalledWith(expectedResponse);
});

it('should pass through to UI when policy says ASK_USER', () => {
vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ASK_USER);

const requestHandler = vi.fn();
messageBus.subscribe(
MessageBusType.TOOL_CONFIRMATION_REQUEST,
requestHandler,
);

const request: ToolConfirmationRequest = {
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: { name: 'test-tool', args: {} },
correlationId: '123',
};

messageBus.publish(request);

expect(requestHandler).toHaveBeenCalledWith(request);
});

it('should emit other message types directly', () => {
const successHandler = vi.fn();
messageBus.subscribe(
MessageBusType.TOOL_EXECUTION_SUCCESS,
successHandler,
);

const message: ToolExecutionSuccess<string> = {
type: MessageBusType.TOOL_EXECUTION_SUCCESS as const,
toolCall: { name: 'test-tool' },
result: 'success',
};

messageBus.publish(message);

expect(successHandler).toHaveBeenCalledWith(message);
});
});

describe('subscribe/unsubscribe', () => {
it('should allow subscribing to specific message types', () => {
const handler = vi.fn();
messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler);

const message: ToolExecutionSuccess<string> = {
type: MessageBusType.TOOL_EXECUTION_SUCCESS as const,
toolCall: { name: 'test' },
result: 'test',
};

messageBus.publish(message);

expect(handler).toHaveBeenCalledWith(message);
});

it('should allow unsubscribing from message types', () => {
const handler = vi.fn();
messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler);
messageBus.unsubscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler);

const message: ToolExecutionSuccess<string> = {
type: MessageBusType.TOOL_EXECUTION_SUCCESS as const,
toolCall: { name: 'test' },
result: 'test',
};

messageBus.publish(message);

expect(handler).not.toHaveBeenCalled();
});

it('should support multiple subscribers for the same message type', () => {
const handler1 = vi.fn();
const handler2 = vi.fn();

messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler1);
messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler2);

const message: ToolExecutionSuccess<string> = {
type: MessageBusType.TOOL_EXECUTION_SUCCESS as const,
toolCall: { name: 'test' },
result: 'test',
};

messageBus.publish(message);

expect(handler1).toHaveBeenCalledWith(message);
expect(handler2).toHaveBeenCalledWith(message);
});
});

describe('error handling', () => {
it('should not crash on errors during message processing', () => {
const errorHandler = vi.fn();
messageBus.on('error', errorHandler);

// Mock policyEngine to throw an error
vi.spyOn(policyEngine, 'check').mockImplementation(() => {
throw new Error('Policy check failed');
});

const request: ToolConfirmationRequest = {
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: { name: 'test-tool' },
correlationId: '123',
};

// Should not throw
expect(() => messageBus.publish(request)).not.toThrow();

// Should emit error
expect(errorHandler).toHaveBeenCalledWith(
expect.objectContaining({
message: 'Policy check failed',
}),
);
});
});
});
98 changes: 98 additions & 0 deletions packages/core/src/confirmation-bus/message-bus.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

import { EventEmitter } from 'node:events';
import type { PolicyEngine } from '../policy/policy-engine.js';
import { PolicyDecision } from '../policy/types.js';
import { MessageBusType, type Message } from './types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';

export class MessageBus extends EventEmitter {
constructor(private readonly policyEngine: PolicyEngine) {
super();
}

private isValidMessage(message: Message): boolean {
if (!message || !message.type) {
return false;
}

if (
message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST &&
!('correlationId' in message)
) {
return false;
}

return true;
}

private emitMessage(message: Message): void {
this.emit(message.type, message);
}

publish(message: Message): void {
try {
if (!this.isValidMessage(message)) {
throw new Error(
`Invalid message structure: ${safeJsonStringify(message)}`,
);
}

if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
const decision = this.policyEngine.check(message.toolCall);

switch (decision) {
case PolicyDecision.ALLOW:
// Directly emit the response instead of recursive publish
this.emitMessage({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: message.correlationId,
confirmed: true,
});
break;
case PolicyDecision.DENY:
// Emit both rejection and response messages
this.emitMessage({
type: MessageBusType.TOOL_POLICY_REJECTION,
toolCall: message.toolCall,
});
this.emitMessage({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: message.correlationId,
confirmed: false,
});
break;
case PolicyDecision.ASK_USER:
// Pass through to UI for user confirmation
this.emitMessage(message);
break;
default:
throw new Error(`Unknown policy decision: ${decision}`);
}
} else {
// For all other message types, just emit them
this.emitMessage(message);
}
} catch (error) {
this.emit('error', error);
}
}

subscribe<T extends Message>(
type: T['type'],
listener: (message: T) => void,
): void {
this.on(type, listener);
}

unsubscribe<T extends Message>(
type: T['type'],
listener: (message: T) => void,
): void {
this.off(type, listener);
}
}
Loading