From d06300f00784175ffd0a987275efb8ed960db412 Mon Sep 17 00:00:00 2001 From: matt korwel Date: Tue, 10 Feb 2026 08:25:21 -0600 Subject: [PATCH 1/3] feat(routing): restrict numerical routing to Gemini 3 family (#18478) # Conflicts: # packages/core/src/routing/strategies/classifierStrategy.test.ts # packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts # packages/core/src/routing/strategies/numericalClassifierStrategy.ts --- packages/core/src/config/models.test.ts | 24 ++++++++ packages/core/src/config/models.ts | 11 ++++ .../strategies/classifierStrategy.test.ts | 26 ++++++++- .../routing/strategies/classifierStrategy.ts | 10 +++- .../numericalClassifierStrategy.test.ts | 56 +++++++++++++++---- .../strategies/numericalClassifierStrategy.ts | 11 +++- 6 files changed, 121 insertions(+), 17 deletions(-) diff --git a/packages/core/src/config/models.test.ts b/packages/core/src/config/models.test.ts index 8e6c3ea8954..73865accd2e 100644 --- a/packages/core/src/config/models.test.ts +++ b/packages/core/src/config/models.test.ts @@ -8,6 +8,7 @@ import { describe, it, expect } from 'vitest'; import { resolveModel, resolveClassifierModel, + isGemini3Model, isGemini2Model, isAutoModel, getDisplayString, @@ -25,6 +26,29 @@ import { DEFAULT_GEMINI_MODEL_AUTO, } from './models.js'; +describe('isGemini3Model', () => { + it('should return true for gemini-3 models', () => { + expect(isGemini3Model('gemini-3-pro-preview')).toBe(true); + expect(isGemini3Model('gemini-3-flash-preview')).toBe(true); + }); + + it('should return true for aliases that resolve to Gemini 3', () => { + expect(isGemini3Model(GEMINI_MODEL_ALIAS_AUTO)).toBe(true); + expect(isGemini3Model(GEMINI_MODEL_ALIAS_PRO)).toBe(true); + expect(isGemini3Model(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true); + }); + + it('should return false for Gemini 2 models', () => { + expect(isGemini3Model('gemini-2.5-pro')).toBe(false); + expect(isGemini3Model('gemini-2.5-flash')).toBe(false); + expect(isGemini3Model(DEFAULT_GEMINI_MODEL_AUTO)).toBe(false); + }); + + it('should return false for arbitrary strings', () => { + expect(isGemini3Model('gpt-4')).toBe(false); + }); +}); + describe('getDisplayString', () => { it('should return Auto (Gemini 3) for preview auto model', () => { expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)'); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index 519f49c98ec..24d747b9696 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -137,6 +137,17 @@ export function isPreviewModel(model: string): boolean { ); } +/** + * Checks if the model is a Gemini 3 model. + * + * @param model The model name to check. + * @returns True if the model is a Gemini 3 model. + */ +export function isGemini3Model(model: string): boolean { + const resolved = resolveModel(model); + return /^gemini-3(\.|-|$)/.test(resolved); +} + /** * Checks if the model is a Gemini 2.x model. * diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts index ef0f784ee20..9b6c3e35d01 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -17,6 +17,7 @@ import { DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL_AUTO, + PREVIEW_GEMINI_MODEL_AUTO, } from '../../config/models.js'; import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; @@ -50,8 +51,12 @@ describe('ClassifierStrategy', () => { modelConfigService: { getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig), }, +<<<<<<< HEAD getModel: () => DEFAULT_GEMINI_MODEL_AUTO, getPreviewFeatures: () => false, +======= + getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO), +>>>>>>> 37f128a10 (feat(routing): restrict numerical routing to Gemini 3 family (#18478)) getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false), } as unknown as Config; mockBaseLlmClient = { @@ -61,8 +66,9 @@ describe('ClassifierStrategy', () => { vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id'); }); - it('should return null if numerical routing is enabled', async () => { + it('should return null if numerical routing is enabled and model is Gemini 3', async () => { vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true); + vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO); const decision = await strategy.route( mockContext, @@ -74,6 +80,24 @@ describe('ClassifierStrategy', () => { expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); }); + it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => { + vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true); + vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({ + reasoning: 'test', + model_choice: 'flash', + }); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).not.toBeNull(); + expect(mockBaseLlmClient.generateJson).toHaveBeenCalled(); + }); + it('should call generateJson with the correct parameters', async () => { const mockApiResponse = { reasoning: 'Simple task', diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts index 4edf85a3515..8396f4aceb1 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -12,7 +12,7 @@ import type { RoutingDecision, RoutingStrategy, } from '../routingStrategy.js'; -import { resolveClassifierModel } from '../../config/models.js'; +import { resolveClassifierModel, isGemini3Model } from '../../config/models.js'; import { createUserContent, Type } from '@google/genai'; import type { Config } from '../../config/config.js'; import { @@ -133,7 +133,11 @@ export class ClassifierStrategy implements RoutingStrategy { ): Promise { const startTime = Date.now(); try { - if (await config.getNumericalRoutingEnabled()) { + const model = context.requestedModel ?? config.getModel(); + if ( + (await config.getNumericalRoutingEnabled()) && + isGemini3Model(model) + ) { return null; } @@ -164,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy { const reasoning = routerResponse.reasoning; const latencyMs = Date.now() - startTime; const selectedModel = resolveClassifierModel( - context.requestedModel ?? config.getModel(), + model, routerResponse.model_choice, config.getPreviewFeatures(), ); diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index 93e75fcdb5b..facae8a1863 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -10,9 +10,11 @@ import type { RoutingContext } from '../routingStrategy.js'; import type { Config } from '../../config/config.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import { - DEFAULT_GEMINI_FLASH_MODEL, - DEFAULT_GEMINI_MODEL, + PREVIEW_GEMINI_FLASH_MODEL, + PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO, + DEFAULT_GEMINI_MODEL, } from '../../config/models.js'; import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; @@ -46,8 +48,12 @@ describe('NumericalClassifierStrategy', () => { modelConfigService: { getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig), }, +<<<<<<< HEAD getModel: () => DEFAULT_GEMINI_MODEL_AUTO, getPreviewFeatures: () => false, +======= + getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO), +>>>>>>> 37f128a10 (feat(routing): restrict numerical routing to Gemini 3 family (#18478)) getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50) getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true), getClassifierThreshold: vi.fn().mockResolvedValue(undefined), @@ -76,6 +82,32 @@ describe('NumericalClassifierStrategy', () => { expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); }); + it('should return null if the model is not a Gemini 3 model', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).toBeNull(); + expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); + }); + + it('should return null if the model is explicitly a Gemini 2 model', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).toBeNull(); + expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); + }); + it('should call generateJson with the correct parameters and wrapped user content', async () => { const mockApiResponse = { complexity_reasoning: 'Simple task', @@ -120,7 +152,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, + model: PREVIEW_GEMINI_FLASH_MODEL, metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), @@ -146,7 +178,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_MODEL, + model: PREVIEW_GEMINI_MODEL, metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), @@ -172,7 +204,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80 + model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80 metadata: { source: 'NumericalClassifier (Strict)', latencyMs: expect.any(Number), @@ -198,7 +230,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_MODEL, + model: PREVIEW_GEMINI_MODEL, metadata: { source: 'NumericalClassifier (Strict)', latencyMs: expect.any(Number), @@ -226,7 +258,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70 + model: PREVIEW_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70 metadata: { source: 'NumericalClassifier (Remote)', latencyMs: expect.any(Number), @@ -252,7 +284,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5 + model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5 metadata: { source: 'NumericalClassifier (Remote)', latencyMs: expect.any(Number), @@ -278,7 +310,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_MODEL, // Score 35 >= Threshold 30 + model: PREVIEW_GEMINI_MODEL, // Score 35 >= Threshold 30 metadata: { source: 'NumericalClassifier (Remote)', latencyMs: expect.any(Number), @@ -306,7 +338,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50 + model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50 metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), @@ -333,7 +365,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, + model: PREVIEW_GEMINI_FLASH_MODEL, metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), @@ -360,7 +392,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_MODEL, + model: PREVIEW_GEMINI_MODEL, metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index 9bcaebf4321..9964e06dea4 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -12,7 +12,7 @@ import type { RoutingDecision, RoutingStrategy, } from '../routingStrategy.js'; -import { resolveClassifierModel } from '../../config/models.js'; +import { resolveClassifierModel, isGemini3Model } from '../../config/models.js'; import { createUserContent, Type } from '@google/genai'; import type { Config } from '../../config/config.js'; import { debugLogger } from '../../utils/debugLogger.js'; @@ -134,10 +134,15 @@ export class NumericalClassifierStrategy implements RoutingStrategy { ): Promise { const startTime = Date.now(); try { + const model = context.requestedModel ?? config.getModel(); if (!(await config.getNumericalRoutingEnabled())) { return null; } + if (!isGemini3Model(model)) { + return null; + } + const promptId = getPromptIdWithFallback('classifier-router'); const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT); @@ -176,11 +181,15 @@ export class NumericalClassifierStrategy implements RoutingStrategy { config.getSessionId() || 'unknown-session', ); +<<<<<<< HEAD const selectedModel = resolveClassifierModel( config.getModel(), modelAlias, config.getPreviewFeatures(), ); +======= + const selectedModel = resolveClassifierModel(model, modelAlias); +>>>>>>> 37f128a10 (feat(routing): restrict numerical routing to Gemini 3 family (#18478)) const latencyMs = Date.now() - startTime; From 459c398e1ef716d9641eeb9584b21eb1e04adb3e Mon Sep 17 00:00:00 2001 From: mkorwel Date: Tue, 10 Feb 2026 13:00:19 -0600 Subject: [PATCH 2/3] fix: update classifier strategies to support preview models in Gemini 3 check --- packages/core/src/config/models.ts | 8 +- .../strategies/classifierStrategy.test.ts | 93 ++++++++----------- .../routing/strategies/classifierStrategy.ts | 3 +- .../numericalClassifierStrategy.test.ts | 31 +++++-- .../strategies/numericalClassifierStrategy.ts | 9 +- 5 files changed, 75 insertions(+), 69 deletions(-) diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index 24d747b9696..5909449e8da 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -141,10 +141,14 @@ export function isPreviewModel(model: string): boolean { * Checks if the model is a Gemini 3 model. * * @param model The model name to check. + * @param previewFeaturesEnabled A boolean indicating if preview features are enabled. * @returns True if the model is a Gemini 3 model. */ -export function isGemini3Model(model: string): boolean { - const resolved = resolveModel(model); +export function isGemini3Model( + model: string, + previewFeaturesEnabled: boolean = false, +): boolean { + const resolved = resolveModel(model, previewFeaturesEnabled); return /^gemini-3(\.|-|$)/.test(resolved); } diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts index 9b6c3e35d01..81426748bf1 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -4,15 +4,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest'; import { ClassifierStrategy } from './classifierStrategy.js'; import type { RoutingContext } from '../routingStrategy.js'; import type { Config } from '../../config/config.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; -import { - isFunctionCall, - isFunctionResponse, -} from '../../utils/messageInspectors.js'; import { DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_MODEL, @@ -32,6 +28,9 @@ describe('ClassifierStrategy', () => { let mockConfig: Config; let mockBaseLlmClient: BaseLlmClient; let mockResolvedConfig: ResolvedModelConfig; + let mockGetModel: Mock; + let mockGetNumericalRoutingEnabled: Mock; + let mockGenerateJson: Mock; beforeEach(() => { vi.clearAllMocks(); @@ -47,28 +46,30 @@ describe('ClassifierStrategy', () => { model: 'classifier', generateContentConfig: {}, } as unknown as ResolvedModelConfig; + + mockGetModel = vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); + mockGetNumericalRoutingEnabled = vi.fn().mockResolvedValue(false); + mockGenerateJson = vi.fn(); + mockConfig = { modelConfigService: { getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig), }, -<<<<<<< HEAD - getModel: () => DEFAULT_GEMINI_MODEL_AUTO, + getModel: mockGetModel, getPreviewFeatures: () => false, -======= - getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO), ->>>>>>> 37f128a10 (feat(routing): restrict numerical routing to Gemini 3 family (#18478)) - getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false), + getNumericalRoutingEnabled: mockGetNumericalRoutingEnabled, } as unknown as Config; + mockBaseLlmClient = { - generateJson: vi.fn(), + generateJson: mockGenerateJson, } as unknown as BaseLlmClient; vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id'); }); it('should return null if numerical routing is enabled and model is Gemini 3', async () => { - vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true); - vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO); + mockGetNumericalRoutingEnabled.mockResolvedValue(true); + mockGetModel.mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO); const decision = await strategy.route( mockContext, @@ -77,13 +78,13 @@ describe('ClassifierStrategy', () => { ); expect(decision).toBeNull(); - expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); + expect(mockGenerateJson).not.toHaveBeenCalled(); }); it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => { - vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true); - vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({ + mockGetNumericalRoutingEnabled.mockResolvedValue(true); + mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); + mockGenerateJson.mockResolvedValue({ reasoning: 'test', model_choice: 'flash', }); @@ -95,7 +96,7 @@ describe('ClassifierStrategy', () => { ); expect(decision).not.toBeNull(); - expect(mockBaseLlmClient.generateJson).toHaveBeenCalled(); + expect(mockGenerateJson).toHaveBeenCalled(); }); it('should call generateJson with the correct parameters', async () => { @@ -103,13 +104,11 @@ describe('ClassifierStrategy', () => { reasoning: 'Simple task', model_choice: 'flash', }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); + mockGenerateJson.mockResolvedValue(mockApiResponse); await strategy.route(mockContext, mockConfig, mockBaseLlmClient); - expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith( + expect(mockGenerateJson).toHaveBeenCalledWith( expect.objectContaining({ modelConfigKey: { model: mockResolvedConfig.model }, promptId: 'test-prompt-id', @@ -122,9 +121,7 @@ describe('ClassifierStrategy', () => { reasoning: 'This is a simple task.', model_choice: 'flash', }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); + mockGenerateJson.mockResolvedValue(mockApiResponse); const decision = await strategy.route( mockContext, @@ -132,7 +129,7 @@ describe('ClassifierStrategy', () => { mockBaseLlmClient, ); - expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce(); + expect(mockGenerateJson).toHaveBeenCalledOnce(); expect(decision).toEqual({ model: DEFAULT_GEMINI_FLASH_MODEL, metadata: { @@ -148,9 +145,7 @@ describe('ClassifierStrategy', () => { reasoning: 'This is a complex task.', model_choice: 'pro', }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); + mockGenerateJson.mockResolvedValue(mockApiResponse); mockContext.request = [{ text: 'how do I build a spaceship?' }]; const decision = await strategy.route( @@ -159,7 +154,7 @@ describe('ClassifierStrategy', () => { mockBaseLlmClient, ); - expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce(); + expect(mockGenerateJson).toHaveBeenCalledOnce(); expect(decision).toEqual({ model: DEFAULT_GEMINI_MODEL, metadata: { @@ -175,7 +170,7 @@ describe('ClassifierStrategy', () => { .spyOn(debugLogger, 'warn') .mockImplementation(() => {}); const testError = new Error('API Failure'); - vi.mocked(mockBaseLlmClient.generateJson).mockRejectedValue(testError); + mockGenerateJson.mockRejectedValue(testError); const decision = await strategy.route( mockContext, @@ -196,9 +191,7 @@ describe('ClassifierStrategy', () => { reasoning: 'This is a simple task.', // model_choice is missing, which will cause a Zod parsing error. }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - malformedApiResponse, - ); + mockGenerateJson.mockResolvedValue(malformedApiResponse); const decision = await strategy.route( mockContext, @@ -227,14 +220,11 @@ describe('ClassifierStrategy', () => { reasoning: 'Simple.', model_choice: 'flash', }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); + mockGenerateJson.mockResolvedValue(mockApiResponse); await strategy.route(mockContext, mockConfig, mockBaseLlmClient); - const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock - .calls[0][0]; + const generateJsonCall = mockGenerateJson.mock.calls[0][0]; const contents = generateJsonCall.contents; const expectedContents = [ @@ -263,14 +253,11 @@ describe('ClassifierStrategy', () => { reasoning: 'Simple.', model_choice: 'flash', }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); + mockGenerateJson.mockResolvedValue(mockApiResponse); await strategy.route(mockContext, mockConfig, mockBaseLlmClient); - const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock - .calls[0][0]; + const generateJsonCall = mockGenerateJson.mock.calls[0][0]; const contents = generateJsonCall.contents; // Manually calculate what the history should be @@ -278,7 +265,10 @@ describe('ClassifierStrategy', () => { const HISTORY_TURNS_FOR_CONTEXT = 4; const historySlice = longHistory.slice(-HISTORY_SEARCH_WINDOW); const cleanHistory = historySlice.filter( - (content) => !isFunctionCall(content) && !isFunctionResponse(content), + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (content: any) => + !content.parts?.[0]?.functionCall && + !content.parts?.[0]?.functionResponse, ); const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT); @@ -299,14 +289,11 @@ describe('ClassifierStrategy', () => { reasoning: 'Simple.', model_choice: 'flash', }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); + mockGenerateJson.mockResolvedValue(mockApiResponse); await strategy.route(mockContext, mockConfig, mockBaseLlmClient); - const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock - .calls[0][0]; + const generateJsonCall = mockGenerateJson.mock.calls[0][0]; expect(generateJsonCall.promptId).toMatch( /^classifier-router-fallback-\d+-\w+$/, @@ -325,9 +312,7 @@ describe('ClassifierStrategy', () => { reasoning: 'Choice is flash', model_choice: 'flash', }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); + mockGenerateJson.mockResolvedValue(mockApiResponse); const contextWithRequestedModel = { ...mockContext, diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts index 8396f4aceb1..d2a8dee80a6 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -134,9 +134,10 @@ export class ClassifierStrategy implements RoutingStrategy { const startTime = Date.now(); try { const model = context.requestedModel ?? config.getModel(); + const previewFeaturesEnabled = config.getPreviewFeatures(); if ( (await config.getNumericalRoutingEnabled()) && - isGemini3Model(model) + isGemini3Model(model, previewFeaturesEnabled) ) { return null; } diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index facae8a1863..07c28b38349 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -12,9 +12,10 @@ import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import { PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_MODEL, - PREVIEW_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL, + PREVIEW_GEMINI_MODEL_AUTO, + GEMINI_MODEL_ALIAS_AUTO, } from '../../config/models.js'; import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; @@ -48,12 +49,8 @@ describe('NumericalClassifierStrategy', () => { modelConfigService: { getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig), }, -<<<<<<< HEAD - getModel: () => DEFAULT_GEMINI_MODEL_AUTO, - getPreviewFeatures: () => false, -======= getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO), ->>>>>>> 37f128a10 (feat(routing): restrict numerical routing to Gemini 3 family (#18478)) + getPreviewFeatures: () => false, getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50) getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true), getClassifierThreshold: vi.fn().mockResolvedValue(undefined), @@ -108,6 +105,28 @@ describe('NumericalClassifierStrategy', () => { expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); }); + it('should return a decision if model is auto and preview features are enabled (resolves to Gemini 3)', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO); + vi.spyOn(mockConfig, 'getPreviewFeatures').mockReturnValue(true); + + const mockApiResponse = { + complexity_reasoning: 'Simple task', + complexity_score: 10, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).not.toBeNull(); + expect(mockBaseLlmClient.generateJson).toHaveBeenCalled(); + }); + it('should call generateJson with the correct parameters and wrapped user content', async () => { const mockApiResponse = { complexity_reasoning: 'Simple task', diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index 9964e06dea4..0cd74fbb583 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -135,11 +135,12 @@ export class NumericalClassifierStrategy implements RoutingStrategy { const startTime = Date.now(); try { const model = context.requestedModel ?? config.getModel(); + const previewFeaturesEnabled = config.getPreviewFeatures(); if (!(await config.getNumericalRoutingEnabled())) { return null; } - if (!isGemini3Model(model)) { + if (!isGemini3Model(model, previewFeaturesEnabled)) { return null; } @@ -181,15 +182,11 @@ export class NumericalClassifierStrategy implements RoutingStrategy { config.getSessionId() || 'unknown-session', ); -<<<<<<< HEAD const selectedModel = resolveClassifierModel( - config.getModel(), + model, modelAlias, config.getPreviewFeatures(), ); -======= - const selectedModel = resolveClassifierModel(model, modelAlias); ->>>>>>> 37f128a10 (feat(routing): restrict numerical routing to Gemini 3 family (#18478)) const latencyMs = Date.now() - startTime; From 8116fb7717eba68fd25820458c1a59740b34e7fb Mon Sep 17 00:00:00 2001 From: mkorwel Date: Tue, 10 Feb 2026 13:28:48 -0600 Subject: [PATCH 3/3] fix: update test expectation for isGemini3Model with aliases --- packages/core/src/config/models.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/core/src/config/models.test.ts b/packages/core/src/config/models.test.ts index 73865accd2e..c60abc80fb6 100644 --- a/packages/core/src/config/models.test.ts +++ b/packages/core/src/config/models.test.ts @@ -33,8 +33,8 @@ describe('isGemini3Model', () => { }); it('should return true for aliases that resolve to Gemini 3', () => { - expect(isGemini3Model(GEMINI_MODEL_ALIAS_AUTO)).toBe(true); - expect(isGemini3Model(GEMINI_MODEL_ALIAS_PRO)).toBe(true); + expect(isGemini3Model(GEMINI_MODEL_ALIAS_AUTO, true)).toBe(true); + expect(isGemini3Model(GEMINI_MODEL_ALIAS_PRO, true)).toBe(true); expect(isGemini3Model(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true); });