From 35819f6e151dbdf39a14be72f92f6e472f863f9d Mon Sep 17 00:00:00 2001 From: James Grugett Date: Sat, 25 Apr 2026 18:58:16 -0700 Subject: [PATCH] Block unverifiable free-mode countries (#551) --- cli/src/components/waiting-room-screen.tsx | 30 +- cli/src/hooks/use-freebuff-session.ts | 33 +- cli/src/utils/error-handling.ts | 4 +- common/src/types/freebuff-session.ts | 5 +- .../completions/__tests__/completions.test.ts | 617 +++++++++++------- web/src/app/api/v1/chat/completions/_post.ts | 199 +++--- .../session/__tests__/session.test.ts | 65 +- .../app/api/v1/freebuff/session/_handlers.ts | 32 +- .../__tests__/free-mode-country.test.ts | 45 ++ web/src/server/free-mode-country.ts | 111 +++- 10 files changed, 741 insertions(+), 400 deletions(-) create mode 100644 web/src/server/__tests__/free-mode-country.test.ts diff --git a/cli/src/components/waiting-room-screen.tsx b/cli/src/components/waiting-room-screen.tsx index d48d986d2a..3399786ec4 100644 --- a/cli/src/components/waiting-room-screen.tsx +++ b/cli/src/components/waiting-room-screen.tsx @@ -221,13 +221,13 @@ export const WaitingRoomScreen: React.FC = ({ / {session.queueDepth} - Wait + Wait {session.position === 1 ? 'any moment now' : formatWait(session.estimatedWaitMs)} - Elapsed + Elapsed {formatElapsed(elapsedMs)} {/* Per-model session quota (e.g. GLM 5.1 caps at 5/20h). Only @@ -237,7 +237,8 @@ export const WaitingRoomScreen: React.FC = ({ Sessions - {session.rateLimit.recentCount} / {session.rateLimit.limit} + {session.rateLimit.recentCount} /{' '} + {session.rateLimit.limit} used in last {session.rateLimit.windowHours}h @@ -262,10 +263,20 @@ export const WaitingRoomScreen: React.FC = ({ ⚠ Free mode isn't available in your region - We detected your location as{' '} - {session.countryCode}, - which is outside the countries where freebuff is currently - offered. Press Ctrl+C to exit. + {session.countryCode === 'UNKNOWN' ? ( + <> + We couldn't verify an eligible location for this request. + VPN, Tor, proxy, or unknown-location traffic can't use + freebuff. Press Ctrl+C to exit. + + ) : ( + <> + We detected your location as{' '} + {session.countryCode}, + which is outside the countries where freebuff is currently + offered. Press Ctrl+C to exit. + + )} )} @@ -279,8 +290,9 @@ export const WaitingRoomScreen: React.FC = ({ ⚠ Account unavailable - This account has been suspended and can't use freebuff. If you think this is a - mistake, contact support@codebuff.com. Press Ctrl+C to exit. + This account has been suspended and can't use freebuff. If you + think this is a mistake, contact support@codebuff.com. Press + Ctrl+C to exit. )} diff --git a/cli/src/hooks/use-freebuff-session.ts b/cli/src/hooks/use-freebuff-session.ts index 19f21ecaa2..5b5a205c84 100644 --- a/cli/src/hooks/use-freebuff-session.ts +++ b/cli/src/hooks/use-freebuff-session.ts @@ -38,7 +38,9 @@ const playAdmissionSound = () => { } const sessionEndpoint = (): string => { - const base = (env.NEXT_PUBLIC_CODEBUFF_APP_URL || 'https://codebuff.com').replace(/\/$/, '') + const base = ( + env.NEXT_PUBLIC_CODEBUFF_APP_URL || 'https://codebuff.com' + ).replace(/\/$/, '') return `${base}/api/v1/freebuff/session` } @@ -73,10 +75,13 @@ async function callSession( // generic error and back off on the 10s error-retry cadence instead of // tight-polling an unrecognized 200 body. if (resp.status === 403) { - const body = (await resp.json().catch(() => null)) as - | FreebuffSessionResponse - | null - if (body && (body.status === 'country_blocked' || body.status === 'banned')) { + const body = (await resp + .json() + .catch(() => null)) as FreebuffSessionResponse | null + if ( + body && + (body.status === 'country_blocked' || body.status === 'banned') + ) { return body } } @@ -85,9 +90,9 @@ async function callSession( // Surface model-switch conflicts and temporary model availability closures // as non-throw states. if (resp.status === 409 && method === 'POST') { - const body = (await resp.json().catch(() => null)) as - | FreebuffSessionResponse - | null + const body = (await resp + .json() + .catch(() => null)) as FreebuffSessionResponse | null if ( body && (body.status === 'model_locked' || body.status === 'model_unavailable') @@ -101,9 +106,9 @@ async function callSession( // status (rather than 200) keeps older CLIs in their error path so they // back off instead of tight-polling an unrecognized 200 body. if (resp.status === 429 && method === 'POST') { - const body = (await resp.json().catch(() => null)) as - | FreebuffSessionResponse - | null + const body = (await resp + .json() + .catch(() => null)) as FreebuffSessionResponse | null if (body && body.status === 'rate_limited') { return body } @@ -190,9 +195,7 @@ export function getFreebuffInstanceId(): string | undefined { * holding (queued, active, or in the post-expiry grace window with a live * instance id). DELETE only matters in those states; otherwise we'd fire a * spurious request the server has nothing to act on. */ -function shouldReleaseSlot( - current: FreebuffSessionResponse | null, -): boolean { +function shouldReleaseSlot(current: FreebuffSessionResponse | null): boolean { if (!current) return false return ( current.status === 'queued' || @@ -312,7 +315,7 @@ export function markFreebuffSessionSuperseded(): void { /** Flip into the terminal `country_blocked` state from outside the poll loop. * Used when the chat-completions gate rejects on country even though the - * session-level country check had failed open (null detection → admitted). + * session-level country check did not catch the request first. * Transitioning the session state here unmounts the Chat surface in favor of * the waiting-room's country_blocked message, so the user can't keep typing * and sending doomed requests. */ diff --git a/cli/src/utils/error-handling.ts b/cli/src/utils/error-handling.ts index 5bedce5d4a..9b624ea520 100644 --- a/cli/src/utils/error-handling.ts +++ b/cli/src/utils/error-handling.ts @@ -60,8 +60,8 @@ export const isFreeModeUnavailableError = (error: unknown): boolean => { /** * Extract the detected countryCode off a free_mode_unavailable error, if the * server included one. Used to populate the country_blocked screen after the - * chat-completions gate rejects a user whose session-level country check had - * previously failed open (null country detection → admitted → now blocked). + * chat-completions gate rejects a user whose session-level country check did + * not catch the request first. */ export const getCountryCodeFromFreeModeError = ( error: unknown, diff --git a/common/src/types/freebuff-session.ts b/common/src/types/freebuff-session.ts index 7789c91f22..7b5fc04922 100644 --- a/common/src/types/freebuff-session.ts +++ b/common/src/types/freebuff-session.ts @@ -98,11 +98,12 @@ export type FreebuffSessionServerResponse = status: 'superseded' } | { - /** Request originated from a country outside the free-mode allowlist. + /** Request originated outside the free-mode allowlist, or from an + * unknown/anonymized location that cannot be trusted for free mode. * Returned before queue admission so users don't wait through the * room only to be rejected on their first chat request. Terminal — * CLI stops polling and shows a "not available in your country" - * screen. `countryCode` is the resolved country for display. */ + * screen. `countryCode` is the resolved country, or UNKNOWN. */ status: 'country_blocked' countryCode: string } diff --git a/web/src/app/api/v1/chat/completions/__tests__/completions.test.ts b/web/src/app/api/v1/chat/completions/__tests__/completions.test.ts index 1aac8800cd..3e4a1149d1 100644 --- a/web/src/app/api/v1/chat/completions/__tests__/completions.test.ts +++ b/web/src/app/api/v1/chat/completions/__tests__/completions.test.ts @@ -19,10 +19,7 @@ import type { BlockGrantResult } from '@codebuff/billing/subscription' import type { GetUserPreferencesFn } from '../_post' describe('/api/v1/chat/completions POST endpoint', () => { - const mockUserData: Record< - string, - { id: string; banned: boolean } - > = { + const mockUserData: Record = { 'test-api-key-123': { id: 'user-123', banned: false, @@ -67,7 +64,12 @@ describe('/api/v1/chat/completions POST endpoint', () => { // flow without seeding a session. Matches the real return for the disabled // path so downstream logic proceeds normally. const mockCheckSessionAdmissibleAllow = async () => - ({ ok: true, reason: 'disabled' } as const) + ({ ok: true, reason: 'disabled' }) as const + + const allowedFreeModeHeaders = (apiKey: string) => ({ + Authorization: `Bearer ${apiKey}`, + 'cf-ipcountry': 'US', + }) beforeEach(() => { nextQuotaReset = new Date( @@ -75,15 +77,15 @@ describe('/api/v1/chat/completions POST endpoint', () => { ).toISOString() mockLogger = { - error: mock(() => { }), - warn: mock(() => { }), - info: mock(() => { }), - debug: mock(() => { }), + error: mock(() => {}), + warn: mock(() => {}), + info: mock(() => {}), + debug: mock(() => {}), } mockLoggerWithContext = mock(() => mockLogger) - mockTrackEvent = mock(() => { }) + mockTrackEvent = mock(() => {}) mockGetUserUsageData = mock(async ({ userId }: { userId: string }) => { if (userId === 'user-no-credits') { @@ -485,7 +487,6 @@ describe('/api/v1/chat/completions POST endpoint', () => { expect(response.status).toBe(200) }) - it('lets a BYOK free-tier new account through the paid-plan gate', async () => { const req = new NextRequest( 'http://localhost:3000/api/v1/chat/completions', @@ -527,7 +528,7 @@ describe('/api/v1/chat/completions POST endpoint', () => { 'http://localhost:3000/api/v1/chat/completions', { method: 'POST', - headers: { Authorization: 'Bearer test-api-key-new-free' }, + headers: allowedFreeModeHeaders('test-api-key-new-free'), body: JSON.stringify({ model: 'minimax/minimax-m2.7', stream: false, @@ -556,6 +557,84 @@ describe('/api/v1/chat/completions POST endpoint', () => { expect(response.status).toBe(200) }) + it('rejects free-mode requests when location is unknown', async () => { + const req = new NextRequest( + 'http://localhost:3000/api/v1/chat/completions', + { + method: 'POST', + headers: { Authorization: 'Bearer test-api-key-new-free' }, + body: JSON.stringify({ + model: 'minimax/minimax-m2.7', + stream: false, + codebuff_metadata: { + run_id: 'run-free', + client_id: 'test-client-id-123', + cost_mode: 'free', + }, + }), + }, + ) + + const response = await postChatCompletions({ + req, + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch: mockFetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + checkSessionAdmissible: mockCheckSessionAdmissibleAllow, + }) + + expect(response.status).toBe(403) + const body = await response.json() + expect(body.error).toBe('free_mode_unavailable') + expect(body.countryCode).toBe('UNKNOWN') + }) + + it('rejects free-mode requests from anonymized Cloudflare country codes', async () => { + const req = new NextRequest( + 'http://localhost:3000/api/v1/chat/completions', + { + method: 'POST', + headers: { + Authorization: 'Bearer test-api-key-new-free', + 'cf-ipcountry': 'T1', + 'x-forwarded-for': '8.8.8.8', + }, + body: JSON.stringify({ + model: 'minimax/minimax-m2.7', + stream: false, + codebuff_metadata: { + run_id: 'run-free', + client_id: 'test-client-id-123', + cost_mode: 'free', + }, + }), + }, + ) + + const response = await postChatCompletions({ + req, + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch: mockFetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + checkSessionAdmissible: mockCheckSessionAdmissibleAllow, + }) + + expect(response.status).toBe(403) + const body = await response.json() + expect(body.error).toBe('free_mode_unavailable') + expect(body.countryCode).toBe('UNKNOWN') + }) + it('lets freebuff use GLM 5.1 through Fireworks availability rules', async () => { const fetchedBodies: Record[] = [] const fetchViaFireworks = mock( @@ -584,7 +663,7 @@ describe('/api/v1/chat/completions POST endpoint', () => { 'http://localhost:3000/api/v1/chat/completions', { method: 'POST', - headers: { Authorization: 'Bearer test-api-key-new-free' }, + headers: allowedFreeModeHeaders('test-api-key-new-free'), body: JSON.stringify({ model: 'z-ai/glm-5.1', stream: false, @@ -631,7 +710,7 @@ describe('/api/v1/chat/completions POST endpoint', () => { 'http://localhost:3000/api/v1/chat/completions', { method: 'POST', - headers: { Authorization: 'Bearer test-api-key-no-credits' }, + headers: allowedFreeModeHeaders('test-api-key-no-credits'), body: JSON.stringify({ model: 'minimax/minimax-m2.7', stream: false, @@ -665,7 +744,7 @@ describe('/api/v1/chat/completions POST endpoint', () => { 'http://localhost:3000/api/v1/chat/completions', { method: 'POST', - headers: { Authorization: 'Bearer test-api-key-new-free' }, + headers: allowedFreeModeHeaders('test-api-key-new-free'), body: JSON.stringify({ // Expensive model the attacker wants for free. model: 'anthropic/claude-4.7-opus', @@ -704,7 +783,7 @@ describe('/api/v1/chat/completions POST endpoint', () => { 'http://localhost:3000/api/v1/chat/completions', { method: 'POST', - headers: { Authorization: 'Bearer test-api-key-new-free' }, + headers: allowedFreeModeHeaders('test-api-key-new-free'), body: JSON.stringify({ model: 'anthropic/claude-4.7-opus', stream: true, @@ -740,7 +819,7 @@ describe('/api/v1/chat/completions POST endpoint', () => { 'http://localhost:3000/api/v1/chat/completions', { method: 'POST', - headers: { Authorization: 'Bearer test-api-key-new-free' }, + headers: allowedFreeModeHeaders('test-api-key-new-free'), body: JSON.stringify({ model: 'minimax/minimax-m2.7', stream: true, @@ -872,183 +951,211 @@ describe('/api/v1/chat/completions POST endpoint', () => { }), }) - it('returns 429 when weekly limit reached and fallback disabled', async () => { - const weeklyLimitError: BlockGrantResult = { - error: 'weekly_limit_reached', - used: 3500, - limit: 3500, - resetsAt: new Date(Date.now() + 3 * 24 * 60 * 60 * 1000), - } - const mockEnsureSubscriberBlockGrant = mock(async () => weeklyLimitError) - const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ - fallbackToALaCarte: false, - })) - - const response = await postChatCompletions({ - req: createValidRequest(), - getUserInfoFromApiKey: mockGetUserInfoFromApiKey, - logger: mockLogger, - trackEvent: mockTrackEvent, - getUserUsageData: mockGetUserUsageData, - getAgentRunFromId: mockGetAgentRunFromId, - fetch: mockFetch, - insertMessageBigquery: mockInsertMessageBigquery, - loggerWithContext: mockLoggerWithContext, - ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, - getUserPreferences: mockGetUserPreferences, - checkSessionAdmissible: mockCheckSessionAdmissibleAllow, - }) - - expect(response.status).toBe(429) - const body = await response.json() - expect(body.error).toBe('rate_limit_exceeded') - expect(body.message).toContain('weekly limit reached') - expect(body.message).toContain('Enable "Continue with credits"') - }, SUBSCRIPTION_TEST_TIMEOUT_MS) - - it('skips subscription limit check when in FREE mode even with fallback disabled', async () => { - const weeklyLimitError: BlockGrantResult = { - error: 'weekly_limit_reached', - used: 3500, - limit: 3500, - resetsAt: new Date(Date.now() + 3 * 24 * 60 * 60 * 1000), - } - const mockEnsureSubscriberBlockGrant = mock(async () => weeklyLimitError) - const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ - fallbackToALaCarte: false, - })) - - const freeModeRequest = new NextRequest( - 'http://localhost:3000/api/v1/chat/completions', - { - method: 'POST', - headers: { Authorization: 'Bearer test-api-key-123' }, - body: JSON.stringify({ - model: 'minimax/minimax-m2.7', - stream: false, - codebuff_metadata: { - run_id: 'run-free', - client_id: 'test-client-id-123', - cost_mode: 'free', - }, - }), - }, - ) - - const response = await postChatCompletions({ - req: freeModeRequest, - getUserInfoFromApiKey: mockGetUserInfoFromApiKey, - logger: mockLogger, - trackEvent: mockTrackEvent, - getUserUsageData: mockGetUserUsageData, - getAgentRunFromId: mockGetAgentRunFromId, - fetch: mockFetch, - insertMessageBigquery: mockInsertMessageBigquery, - loggerWithContext: mockLoggerWithContext, - ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, - getUserPreferences: mockGetUserPreferences, - checkSessionAdmissible: mockCheckSessionAdmissibleAllow, - }) - - expect(response.status).toBe(200) - }, SUBSCRIPTION_TEST_TIMEOUT_MS) - - it('returns 429 when block exhausted and fallback disabled', async () => { - const blockExhaustedError: BlockGrantResult = { - error: 'block_exhausted', - blockUsed: 350, - blockLimit: 350, - resetsAt: new Date(Date.now() + 4 * 60 * 60 * 1000), - } - const mockEnsureSubscriberBlockGrant = mock(async () => blockExhaustedError) - const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ - fallbackToALaCarte: false, - })) + it( + 'returns 429 when weekly limit reached and fallback disabled', + async () => { + const weeklyLimitError: BlockGrantResult = { + error: 'weekly_limit_reached', + used: 3500, + limit: 3500, + resetsAt: new Date(Date.now() + 3 * 24 * 60 * 60 * 1000), + } + const mockEnsureSubscriberBlockGrant = mock( + async () => weeklyLimitError, + ) + const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ + fallbackToALaCarte: false, + })) + + const response = await postChatCompletions({ + req: createValidRequest(), + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch: mockFetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, + getUserPreferences: mockGetUserPreferences, + checkSessionAdmissible: mockCheckSessionAdmissibleAllow, + }) - const response = await postChatCompletions({ - req: createValidRequest(), - getUserInfoFromApiKey: mockGetUserInfoFromApiKey, - logger: mockLogger, - trackEvent: mockTrackEvent, - getUserUsageData: mockGetUserUsageData, - getAgentRunFromId: mockGetAgentRunFromId, - fetch: mockFetch, - insertMessageBigquery: mockInsertMessageBigquery, - loggerWithContext: mockLoggerWithContext, - ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, - getUserPreferences: mockGetUserPreferences, - checkSessionAdmissible: mockCheckSessionAdmissibleAllow, - }) + expect(response.status).toBe(429) + const body = await response.json() + expect(body.error).toBe('rate_limit_exceeded') + expect(body.message).toContain('weekly limit reached') + expect(body.message).toContain('Enable "Continue with credits"') + }, + SUBSCRIPTION_TEST_TIMEOUT_MS, + ) + + it( + 'skips subscription limit check when in FREE mode even with fallback disabled', + async () => { + const weeklyLimitError: BlockGrantResult = { + error: 'weekly_limit_reached', + used: 3500, + limit: 3500, + resetsAt: new Date(Date.now() + 3 * 24 * 60 * 60 * 1000), + } + const mockEnsureSubscriberBlockGrant = mock( + async () => weeklyLimitError, + ) + const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ + fallbackToALaCarte: false, + })) - expect(response.status).toBe(429) - const body = await response.json() - expect(body.error).toBe('rate_limit_exceeded') - expect(body.message).toContain('5-hour session limit reached') - expect(body.message).toContain('Enable "Continue with credits"') - }, SUBSCRIPTION_TEST_TIMEOUT_MS) - - it('continues when weekly limit reached but fallback is enabled', async () => { - const weeklyLimitError: BlockGrantResult = { - error: 'weekly_limit_reached', - used: 3500, - limit: 3500, - resetsAt: new Date(Date.now() + 3 * 24 * 60 * 60 * 1000), - } - const mockEnsureSubscriberBlockGrant = mock(async () => weeklyLimitError) - const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ - fallbackToALaCarte: true, - })) + const freeModeRequest = new NextRequest( + 'http://localhost:3000/api/v1/chat/completions', + { + method: 'POST', + headers: allowedFreeModeHeaders('test-api-key-123'), + body: JSON.stringify({ + model: 'minimax/minimax-m2.7', + stream: false, + codebuff_metadata: { + run_id: 'run-free', + client_id: 'test-client-id-123', + cost_mode: 'free', + }, + }), + }, + ) - const response = await postChatCompletions({ - req: createValidRequest(), - getUserInfoFromApiKey: mockGetUserInfoFromApiKey, - logger: mockLogger, - trackEvent: mockTrackEvent, - getUserUsageData: mockGetUserUsageData, - getAgentRunFromId: mockGetAgentRunFromId, - fetch: mockFetch, - insertMessageBigquery: mockInsertMessageBigquery, - loggerWithContext: mockLoggerWithContext, - ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, - getUserPreferences: mockGetUserPreferences, - checkSessionAdmissible: mockCheckSessionAdmissibleAllow, - }) + const response = await postChatCompletions({ + req: freeModeRequest, + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch: mockFetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, + getUserPreferences: mockGetUserPreferences, + checkSessionAdmissible: mockCheckSessionAdmissibleAllow, + }) - expect(response.status).toBe(200) - expect(mockLogger.info).toHaveBeenCalled() - }, SUBSCRIPTION_TEST_TIMEOUT_MS) + expect(response.status).toBe(200) + }, + SUBSCRIPTION_TEST_TIMEOUT_MS, + ) + + it( + 'returns 429 when block exhausted and fallback disabled', + async () => { + const blockExhaustedError: BlockGrantResult = { + error: 'block_exhausted', + blockUsed: 350, + blockLimit: 350, + resetsAt: new Date(Date.now() + 4 * 60 * 60 * 1000), + } + const mockEnsureSubscriberBlockGrant = mock( + async () => blockExhaustedError, + ) + const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ + fallbackToALaCarte: false, + })) + + const response = await postChatCompletions({ + req: createValidRequest(), + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch: mockFetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, + getUserPreferences: mockGetUserPreferences, + checkSessionAdmissible: mockCheckSessionAdmissibleAllow, + }) - it('continues when block grant is created successfully', async () => { - const blockGrant: BlockGrantResult = { - grantId: 'block-123', - credits: 350, - expiresAt: new Date(Date.now() + 5 * 60 * 60 * 1000), - isNew: true, - } - const mockEnsureSubscriberBlockGrant = mock(async () => blockGrant) - const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ - fallbackToALaCarte: false, - })) + expect(response.status).toBe(429) + const body = await response.json() + expect(body.error).toBe('rate_limit_exceeded') + expect(body.message).toContain('5-hour session limit reached') + expect(body.message).toContain('Enable "Continue with credits"') + }, + SUBSCRIPTION_TEST_TIMEOUT_MS, + ) + + it( + 'continues when weekly limit reached but fallback is enabled', + async () => { + const weeklyLimitError: BlockGrantResult = { + error: 'weekly_limit_reached', + used: 3500, + limit: 3500, + resetsAt: new Date(Date.now() + 3 * 24 * 60 * 60 * 1000), + } + const mockEnsureSubscriberBlockGrant = mock( + async () => weeklyLimitError, + ) + const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ + fallbackToALaCarte: true, + })) + + const response = await postChatCompletions({ + req: createValidRequest(), + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch: mockFetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, + getUserPreferences: mockGetUserPreferences, + checkSessionAdmissible: mockCheckSessionAdmissibleAllow, + }) - const response = await postChatCompletions({ - req: createValidRequest(), - getUserInfoFromApiKey: mockGetUserInfoFromApiKey, - logger: mockLogger, - trackEvent: mockTrackEvent, - getUserUsageData: mockGetUserUsageData, - getAgentRunFromId: mockGetAgentRunFromId, - fetch: mockFetch, - insertMessageBigquery: mockInsertMessageBigquery, - loggerWithContext: mockLoggerWithContext, - ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, - getUserPreferences: mockGetUserPreferences, - checkSessionAdmissible: mockCheckSessionAdmissibleAllow, - }) + expect(response.status).toBe(200) + expect(mockLogger.info).toHaveBeenCalled() + }, + SUBSCRIPTION_TEST_TIMEOUT_MS, + ) + + it( + 'continues when block grant is created successfully', + async () => { + const blockGrant: BlockGrantResult = { + grantId: 'block-123', + credits: 350, + expiresAt: new Date(Date.now() + 5 * 60 * 60 * 1000), + isNew: true, + } + const mockEnsureSubscriberBlockGrant = mock(async () => blockGrant) + const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ + fallbackToALaCarte: false, + })) + + const response = await postChatCompletions({ + req: createValidRequest(), + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch: mockFetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, + getUserPreferences: mockGetUserPreferences, + checkSessionAdmissible: mockCheckSessionAdmissibleAllow, + }) - expect(response.status).toBe(200) - // getUserPreferences should not be called when block grant succeeds - expect(mockGetUserPreferences).not.toHaveBeenCalled() - }, SUBSCRIPTION_TEST_TIMEOUT_MS) + expect(response.status).toBe(200) + // getUserPreferences should not be called when block grant succeeds + expect(mockGetUserPreferences).not.toHaveBeenCalled() + }, + SUBSCRIPTION_TEST_TIMEOUT_MS, + ) it.skip('continues when ensureSubscriberBlockGrant throws an error (fail open)', async () => { const mockEnsureSubscriberBlockGrant = mock(async () => { @@ -1078,58 +1185,68 @@ describe('/api/v1/chat/completions POST endpoint', () => { expect(mockLogger.error).toHaveBeenCalled() }) - it.skip('continues when user is not a subscriber (null result)', async () => { - const mockEnsureSubscriberBlockGrant = mock(async () => null) - const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ - fallbackToALaCarte: false, - })) - - const response = await postChatCompletions({ - req: createValidRequest(), - getUserInfoFromApiKey: mockGetUserInfoFromApiKey, - logger: mockLogger, - trackEvent: mockTrackEvent, - getUserUsageData: mockGetUserUsageData, - getAgentRunFromId: mockGetAgentRunFromId, - fetch: mockFetch, - insertMessageBigquery: mockInsertMessageBigquery, - loggerWithContext: mockLoggerWithContext, - ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, - getUserPreferences: mockGetUserPreferences, - checkSessionAdmissible: mockCheckSessionAdmissibleAllow, - }) + it.skip( + 'continues when user is not a subscriber (null result)', + async () => { + const mockEnsureSubscriberBlockGrant = mock(async () => null) + const mockGetUserPreferences: GetUserPreferencesFn = mock(async () => ({ + fallbackToALaCarte: false, + })) + + const response = await postChatCompletions({ + req: createValidRequest(), + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch: mockFetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, + getUserPreferences: mockGetUserPreferences, + checkSessionAdmissible: mockCheckSessionAdmissibleAllow, + }) - expect(response.status).toBe(200) - // getUserPreferences should not be called for non-subscribers - expect(mockGetUserPreferences).not.toHaveBeenCalled() - }, SUBSCRIPTION_TEST_TIMEOUT_MS) - - it.skip('defaults to allowing fallback when getUserPreferences is not provided', async () => { - const weeklyLimitError: BlockGrantResult = { - error: 'weekly_limit_reached', - used: 3500, - limit: 3500, - resetsAt: new Date(Date.now() + 3 * 24 * 60 * 60 * 1000), - } - const mockEnsureSubscriberBlockGrant = mock(async () => weeklyLimitError) + expect(response.status).toBe(200) + // getUserPreferences should not be called for non-subscribers + expect(mockGetUserPreferences).not.toHaveBeenCalled() + }, + SUBSCRIPTION_TEST_TIMEOUT_MS, + ) + + it.skip( + 'defaults to allowing fallback when getUserPreferences is not provided', + async () => { + const weeklyLimitError: BlockGrantResult = { + error: 'weekly_limit_reached', + used: 3500, + limit: 3500, + resetsAt: new Date(Date.now() + 3 * 24 * 60 * 60 * 1000), + } + const mockEnsureSubscriberBlockGrant = mock( + async () => weeklyLimitError, + ) - const response = await postChatCompletions({ - req: createValidRequest(), - getUserInfoFromApiKey: mockGetUserInfoFromApiKey, - logger: mockLogger, - trackEvent: mockTrackEvent, - getUserUsageData: mockGetUserUsageData, - getAgentRunFromId: mockGetAgentRunFromId, - fetch: mockFetch, - insertMessageBigquery: mockInsertMessageBigquery, - loggerWithContext: mockLoggerWithContext, - ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, - // Note: getUserPreferences is NOT provided - }) + const response = await postChatCompletions({ + req: createValidRequest(), + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch: mockFetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + ensureSubscriberBlockGrant: mockEnsureSubscriberBlockGrant, + // Note: getUserPreferences is NOT provided + }) - // Should continue processing (default to allowing a-la-carte) - expect(response.status).toBe(200) - }, SUBSCRIPTION_TEST_TIMEOUT_MS) + // Should continue processing (default to allowing a-la-carte) + expect(response.status).toBe(200) + }, + SUBSCRIPTION_TEST_TIMEOUT_MS, + ) it.skip('allows subscriber with 0 a-la-carte credits but active block grant', async () => { const blockGrant: BlockGrantResult = { @@ -1141,17 +1258,23 @@ describe('/api/v1/chat/completions POST endpoint', () => { const mockEnsureSubscriberBlockGrant = mock(async () => blockGrant) // Override mock: when subscription credits are included, simulate the block grant's credits - mockGetUserUsageData = mock(async ({ includeSubscriptionCredits }: { includeSubscriptionCredits?: boolean }) => ({ - usageThisCycle: 0, - balance: { - totalRemaining: includeSubscriptionCredits ? 350 : 0, - totalDebt: 0, - netBalance: includeSubscriptionCredits ? 350 : 0, - breakdown: {}, - principals: { subscription: 350 }, - }, - nextQuotaReset, - })) + mockGetUserUsageData = mock( + async ({ + includeSubscriptionCredits, + }: { + includeSubscriptionCredits?: boolean + }) => ({ + usageThisCycle: 0, + balance: { + totalRemaining: includeSubscriptionCredits ? 350 : 0, + totalDebt: 0, + netBalance: includeSubscriptionCredits ? 350 : 0, + breakdown: {}, + principals: { subscription: 350 }, + }, + nextQuotaReset, + }), + ) // Use the no-credits user (totalRemaining = 0 without subscription) const req = new NextRequest( diff --git a/web/src/app/api/v1/chat/completions/_post.ts b/web/src/app/api/v1/chat/completions/_post.ts index 13baada653..426f65e187 100644 --- a/web/src/app/api/v1/chat/completions/_post.ts +++ b/web/src/app/api/v1/chat/completions/_post.ts @@ -7,7 +7,6 @@ import { import { getErrorObject } from '@codebuff/common/util/error' import { pluralize } from '@codebuff/common/util/string' import { env } from '@codebuff/internal/env' -import geoip from 'geoip-lite' import { NextResponse } from 'next/server' import type { TrackEventFn } from '@codebuff/common/types/contracts/analytics' @@ -22,9 +21,7 @@ import type { LoggerWithContextFn, } from '@codebuff/common/types/contracts/logger' -import type { - BlockGrantResult, -} from '@codebuff/billing/subscription' +import type { BlockGrantResult } from '@codebuff/billing/subscription' import { isWeeklyLimitError, isBlockExhaustedError, @@ -68,11 +65,7 @@ import { OpenRouterError, } from '@/llm-api/openrouter' import { checkSessionAdmissible } from '@/server/free-session/public-api' -import { - FREE_MODE_ALLOWED_COUNTRIES, - extractClientIp, - getCountryCode, -} from '@/server/free-mode-country' +import { getFreeModeCountryAccess } from '@/server/free-mode-country' import type { SessionGateResult } from '@/server/free-session/public-api' import { extractApiKeyFromHeader } from '@/util/auth' @@ -138,7 +131,10 @@ export async function postChatCompletions(params: { getAgentRunFromId: GetAgentRunFromIdFn fetch: typeof globalThis.fetch insertMessageBigquery: InsertMessageBigqueryFn - ensureSubscriberBlockGrant?: (params: { userId: string; logger: Logger }) => Promise + ensureSubscriberBlockGrant?: (params: { + userId: string + logger: Logger + }) => Promise getUserPreferences?: GetUserPreferencesFn /** Optional override for the freebuff waiting-room gate. Defaults to the * real check backed by Postgres; tests inject a no-op. */ @@ -187,7 +183,9 @@ export async function postChatCompletions(params: { const costMode = typedBody.codebuff_metadata?.cost_mode const isFreeModeRequest = isFreeMode(costMode) - trackEvent = withDefaultProperties(trackEvent, { freebuff: isFreeModeRequest }) + trackEvent = withDefaultProperties(trackEvent, { + freebuff: isFreeModeRequest, + }) // Extract and validate API key const apiKey = extractApiKeyFromHeader(req) @@ -256,28 +254,30 @@ export async function postChatCompletions(params: { logger, }) - // For free mode requests, check if user is in US or Canada + // For free mode requests, require a resolved allowlisted country. if (isFreeModeRequest) { - const countryCode = getCountryCode(req) - const clientIp = extractClientIp(req) + const countryAccess = getFreeModeCountryAccess(req) - const cfHeader = req.headers.get('cf-ipcountry') - const geoipResult = clientIp ? geoip.lookup(clientIp)?.country ?? null : null logger.info( - { cfHeader, geoipResult, resolvedCountry: countryCode, clientIp: clientIp ? '[redacted]' : undefined }, + { + cfHeader: countryAccess.cfCountry, + geoipResult: countryAccess.geoipCountry, + resolvedCountry: countryAccess.countryCode, + countryBlockReason: countryAccess.blockReason, + clientIp: countryAccess.hasClientIp ? '[redacted]' : undefined, + }, 'Free mode country detection', ) - // If we couldn't determine country (null), allow the request (fail open) - // This handles users behind VPNs, corporate proxies, or localhost - if (countryCode && !FREE_MODE_ALLOWED_COUNTRIES.has(countryCode)) { + if (!countryAccess.allowed) { trackEvent({ event: AnalyticsEvent.CHAT_COMPLETIONS_VALIDATION_ERROR, userId, properties: { error: 'free_mode_not_available_in_country', - countryCode, - clientIp: clientIp ? '[redacted]' : undefined, + countryCode: countryAccess.countryCode, + countryBlockReason: countryAccess.blockReason, + clientIp: countryAccess.hasClientIp ? '[redacted]' : undefined, }, logger, }) @@ -286,12 +286,11 @@ export async function postChatCompletions(params: { { error: 'free_mode_unavailable', message: 'Free mode is not available in your country.', - countryCode, + countryCode: countryAccess.countryCode ?? 'UNKNOWN', }, { status: 403 }, ) } - } // Extract and validate agent run ID @@ -417,7 +416,9 @@ export async function postChatCompletions(params: { const rateLimitResult = checkFreeModeRateLimit(userId) if (rateLimitResult.limited) { const retryAfterSeconds = Math.ceil(rateLimitResult.retryAfterMs / 1000) - const resetTime = new Date(Date.now() + rateLimitResult.retryAfterMs).toISOString() + const resetTime = new Date( + Date.now() + rateLimitResult.retryAfterMs, + ).toISOString() const resetCountdown = formatQuotaResetCountdown(resetTime) trackEvent({ @@ -451,10 +452,17 @@ export async function postChatCompletions(params: { const includeSubscriptionCredits = !!ensureSubscriberBlockGrant if (ensureSubscriberBlockGrant) { try { - const blockGrantResult = await ensureSubscriberBlockGrant({ userId, logger }) + const blockGrantResult = await ensureSubscriberBlockGrant({ + userId, + logger, + }) // Check if user hit subscription limit and should be rate-limited - if (blockGrantResult && (isWeeklyLimitError(blockGrantResult) || isBlockExhaustedError(blockGrantResult))) { + if ( + blockGrantResult && + (isWeeklyLimitError(blockGrantResult) || + isBlockExhaustedError(blockGrantResult)) + ) { // Fetch user's preference for falling back to a-la-carte credits const preferences = getUserPreferences ? await getUserPreferences({ userId, logger }) @@ -462,8 +470,12 @@ export async function postChatCompletions(params: { if (!preferences.fallbackToALaCarte && !isFreeModeRequest) { const resetTime = blockGrantResult.resetsAt - const resetCountdown = formatQuotaResetCountdown(resetTime.toISOString()) - const limitType = isWeeklyLimitError(blockGrantResult) ? 'weekly' : '5-hour session' + const resetCountdown = formatQuotaResetCountdown( + resetTime.toISOString(), + ) + const limitType = isWeeklyLimitError(blockGrantResult) + ? 'weekly' + : '5-hour session' trackEvent({ event: AnalyticsEvent.CHAT_COMPLETIONS_INSUFFICIENT_CREDITS, @@ -486,7 +498,12 @@ export async function postChatCompletions(params: { } // If fallbackToALaCarte is true, continue to use a-la-carte credits logger.info( - { userId, limitType: isWeeklyLimitError(blockGrantResult) ? 'weekly' : 'session' }, + { + userId, + limitType: isWeeklyLimitError(blockGrantResult) + ? 'weekly' + : 'session', + }, 'Subscriber hit limit, falling back to a-la-carte credits', ) } @@ -535,19 +552,11 @@ export async function postChatCompletions(params: { const useCanopyWave = isCanopyWaveModel(typedBody.model) const useFireworks = !useCanopyWave && isFireworksModel(typedBody.model) const useOpenAIDirect = - !useCanopyWave && !useFireworks && isOpenAIDirectModel(typedBody.model) + !useCanopyWave && + !useFireworks && + isOpenAIDirectModel(typedBody.model) const stream = useSiliconFlow ? await handleSiliconFlowStream({ - body: typedBody, - userId, - stripeCustomerId, - agentId, - fetch, - logger, - insertMessageBigquery, - }) - : useCanopyWave - ? await handleCanopyWaveStream({ body: typedBody, userId, stripeCustomerId, @@ -556,8 +565,8 @@ export async function postChatCompletions(params: { logger, insertMessageBigquery, }) - : useFireworks - ? await handleFireworksStream({ + : useCanopyWave + ? await handleCanopyWaveStream({ body: typedBody, userId, stripeCustomerId, @@ -566,8 +575,8 @@ export async function postChatCompletions(params: { logger, insertMessageBigquery, }) - : useOpenAIDirect - ? await handleOpenAIStream({ + : useFireworks + ? await handleFireworksStream({ body: typedBody, userId, stripeCustomerId, @@ -576,16 +585,26 @@ export async function postChatCompletions(params: { logger, insertMessageBigquery, }) + : useOpenAIDirect + ? await handleOpenAIStream({ + body: typedBody, + userId, + stripeCustomerId, + agentId, + fetch, + logger, + insertMessageBigquery, + }) : await handleOpenRouterStream({ - body: typedBody, - userId, - stripeCustomerId, - agentId, - openrouterApiKey, - fetch, - logger, - insertMessageBigquery, - }) + body: typedBody, + userId, + stripeCustomerId, + agentId, + openrouterApiKey, + fetch, + logger, + insertMessageBigquery, + }) trackEvent({ event: AnalyticsEvent.CHAT_COMPLETIONS_STREAM_STARTED, @@ -616,16 +635,6 @@ export async function postChatCompletions(params: { const nonStreamRequest = useSiliconFlow ? handleSiliconFlowNonStream({ - body: typedBody, - userId, - stripeCustomerId, - agentId, - fetch, - logger, - insertMessageBigquery, - }) - : useCanopyWave - ? handleCanopyWaveNonStream({ body: typedBody, userId, stripeCustomerId, @@ -634,8 +643,8 @@ export async function postChatCompletions(params: { logger, insertMessageBigquery, }) - : useFireworks - ? handleFireworksNonStream({ + : useCanopyWave + ? handleCanopyWaveNonStream({ body: typedBody, userId, stripeCustomerId, @@ -644,8 +653,8 @@ export async function postChatCompletions(params: { logger, insertMessageBigquery, }) - : shouldUseOpenAIEndpoint - ? handleOpenAINonStream({ + : useFireworks + ? handleFireworksNonStream({ body: typedBody, userId, stripeCustomerId, @@ -654,16 +663,26 @@ export async function postChatCompletions(params: { logger, insertMessageBigquery, }) + : shouldUseOpenAIEndpoint + ? handleOpenAINonStream({ + body: typedBody, + userId, + stripeCustomerId, + agentId, + fetch, + logger, + insertMessageBigquery, + }) : handleOpenRouterNonStream({ - body: typedBody, - userId, - stripeCustomerId, - agentId, - openrouterApiKey, - fetch, - logger, - insertMessageBigquery, - }) + body: typedBody, + userId, + stripeCustomerId, + agentId, + openrouterApiKey, + fetch, + logger, + insertMessageBigquery, + }) const result = await nonStreamRequest trackEvent({ @@ -703,7 +722,15 @@ export async function postChatCompletions(params: { // Log detailed error information for debugging const errorDetails = openrouterError?.toJSON() - const providerLabel = siliconflowError ? 'SiliconFlow' : canopywaveError ? 'CanopyWave' : fireworksError ? 'Fireworks' : openaiError ? 'OpenAI' : 'OpenRouter' + const providerLabel = siliconflowError + ? 'SiliconFlow' + : canopywaveError + ? 'CanopyWave' + : fireworksError + ? 'Fireworks' + : openaiError + ? 'OpenAI' + : 'OpenRouter' logger.error( { error: getErrorObject(error), @@ -717,8 +744,20 @@ export async function postChatCompletions(params: { ? typedBody.messages.length : 0, messages: typedBody.messages, - providerStatusCode: (openrouterError ?? fireworksError ?? canopywaveError ?? siliconflowError ?? openaiError)?.statusCode, - providerStatusText: (openrouterError ?? fireworksError ?? canopywaveError ?? siliconflowError ?? openaiError)?.statusText, + providerStatusCode: ( + openrouterError ?? + fireworksError ?? + canopywaveError ?? + siliconflowError ?? + openaiError + )?.statusCode, + providerStatusText: ( + openrouterError ?? + fireworksError ?? + canopywaveError ?? + siliconflowError ?? + openaiError + )?.statusText, openrouterErrorCode: errorDetails?.error?.code, openrouterErrorType: errorDetails?.error?.type, openrouterErrorMessage: errorDetails?.error?.message, diff --git a/web/src/app/api/v1/freebuff/session/__tests__/session.test.ts b/web/src/app/api/v1/freebuff/session/__tests__/session.test.ts index 7ed29ec4b5..676dea44f8 100644 --- a/web/src/app/api/v1/freebuff/session/__tests__/session.test.ts +++ b/web/src/app/api/v1/freebuff/session/__tests__/session.test.ts @@ -17,12 +17,17 @@ const DEFAULT_MODEL = 'minimax/minimax-m2.7' function makeReq( apiKey: string | null, - opts: { instanceId?: string; cfCountry?: string; model?: string } = {}, + opts: { + instanceId?: string + cfCountry?: string | null + model?: string + } = {}, ): NextRequest { const headers = new Headers() if (apiKey) headers.set('Authorization', `Bearer ${apiKey}`) if (opts.instanceId) headers.set(FREEBUFF_INSTANCE_HEADER, opts.instanceId) - if (opts.cfCountry) headers.set('cf-ipcountry', opts.cfCountry) + const cfCountry = opts.cfCountry === null ? null : (opts.cfCountry ?? 'US') + if (cfCountry) headers.set('cf-ipcountry', cfCountry) if (opts.model) headers.set(FREEBUFF_MODEL_HEADER, opts.model) return { headers, @@ -107,19 +112,28 @@ function makeDeps( describe('POST /api/v1/freebuff/session', () => { test('401 when Authorization header is missing', async () => { const sessionDeps = makeSessionDeps() - const resp = await postFreebuffSession(makeReq(null), makeDeps(sessionDeps, null)) + const resp = await postFreebuffSession( + makeReq(null), + makeDeps(sessionDeps, null), + ) expect(resp.status).toBe(401) }) test('401 when API key is invalid', async () => { const sessionDeps = makeSessionDeps() - const resp = await postFreebuffSession(makeReq('bad'), makeDeps(sessionDeps, null)) + const resp = await postFreebuffSession( + makeReq('bad'), + makeDeps(sessionDeps, null), + ) expect(resp.status).toBe(401) }) test('creates a queued session for authed user', async () => { const sessionDeps = makeSessionDeps() - const resp = await postFreebuffSession(makeReq('ok'), makeDeps(sessionDeps, 'u1')) + const resp = await postFreebuffSession( + makeReq('ok'), + makeDeps(sessionDeps, 'u1'), + ) expect(resp.status).toBe(200) const body = await resp.json() expect(body.status).toBe('queued') @@ -128,7 +142,10 @@ describe('POST /api/v1/freebuff/session', () => { test('returns disabled when waiting room flag is off', async () => { const sessionDeps = makeSessionDeps({ isWaitingRoomEnabled: () => false }) - const resp = await postFreebuffSession(makeReq('ok'), makeDeps(sessionDeps, 'u1')) + const resp = await postFreebuffSession( + makeReq('ok'), + makeDeps(sessionDeps, 'u1'), + ) const body = await resp.json() expect(body.status).toBe('disabled') }) @@ -148,6 +165,32 @@ describe('POST /api/v1/freebuff/session', () => { expect(sessionDeps.rows.size).toBe(0) }) + test('returns country_blocked without joining the queue when country is unknown', async () => { + const sessionDeps = makeSessionDeps() + const resp = await postFreebuffSession( + makeReq('ok', { cfCountry: null }), + makeDeps(sessionDeps, 'u1'), + ) + expect(resp.status).toBe(403) + const body = await resp.json() + expect(body.status).toBe('country_blocked') + expect(body.countryCode).toBe('UNKNOWN') + expect(sessionDeps.rows.size).toBe(0) + }) + + test('returns country_blocked without joining the queue for anonymized Cloudflare country', async () => { + const sessionDeps = makeSessionDeps() + const resp = await postFreebuffSession( + makeReq('ok', { cfCountry: 'T1' }), + makeDeps(sessionDeps, 'u1'), + ) + expect(resp.status).toBe(403) + const body = await resp.json() + expect(body.status).toBe('country_blocked') + expect(body.countryCode).toBe('UNKNOWN') + expect(sessionDeps.rows.size).toBe(0) + }) + test('allows queue entry for allowed country', async () => { const sessionDeps = makeSessionDeps() const resp = await postFreebuffSession( @@ -191,7 +234,10 @@ describe('POST /api/v1/freebuff/session', () => { describe('GET /api/v1/freebuff/session', () => { test('returns { status: none } when user has no session', async () => { const sessionDeps = makeSessionDeps() - const resp = await getFreebuffSession(makeReq('ok'), makeDeps(sessionDeps, 'u1')) + const resp = await getFreebuffSession( + makeReq('ok'), + makeDeps(sessionDeps, 'u1'), + ) expect(resp.status).toBe(200) const body = await resp.json() expect(body.status).toBe('none') @@ -257,7 +303,10 @@ describe('DELETE /api/v1/freebuff/session', () => { created_at: new Date(), updated_at: new Date(), }) - const resp = await deleteFreebuffSession(makeReq('ok'), makeDeps(sessionDeps, 'u1')) + const resp = await deleteFreebuffSession( + makeReq('ok'), + makeDeps(sessionDeps, 'u1'), + ) expect(resp.status).toBe(200) expect(sessionDeps.rows.has('u1')).toBe(false) }) diff --git a/web/src/app/api/v1/freebuff/session/_handlers.ts b/web/src/app/api/v1/freebuff/session/_handlers.ts index 9a2d61899f..1ad7fea3c3 100644 --- a/web/src/app/api/v1/freebuff/session/_handlers.ts +++ b/web/src/app/api/v1/freebuff/session/_handlers.ts @@ -5,10 +5,7 @@ import { getSessionState, requestSession, } from '@/server/free-session/public-api' -import { - FREE_MODE_ALLOWED_COUNTRIES, - getCountryCode, -} from '@/server/free-mode-country' +import { getFreeModeCountryAccess } from '@/server/free-mode-country' import { extractApiKeyFromHeader } from '@/util/auth' import type { SessionDeps } from '@/server/free-session/public-api' @@ -16,22 +13,23 @@ import type { GetUserInfoFromApiKeyFn } from '@codebuff/common/types/contracts/d import type { Logger } from '@codebuff/common/types/contracts/logger' import type { NextRequest } from 'next/server' -/** Early country gate. Mirrors the chat/completions check: if we can resolve - * the caller's country and it's not on the allowlist, short-circuit with a - * terminal `country_blocked` response so the CLI can show the warning - * screen without ever joining the queue. Null country (VPN / localhost) - * fails open — chat/completions will catch it later if it matters. +/** Early country gate. Mirrors the chat/completions check: require a resolved + * allowlisted country before joining the queue. Unknown/anonymized locations + * are treated as blocked because they commonly indicate VPN, Tor, localhost, + * or proxy traffic. * * Returns HTTP 403 (not 200) so older CLIs — which don't know the * `country_blocked` status and would tight-poll on an unrecognized 200 * body — fall into their existing `!resp.ok` error path and back off on * the 10s error retry cadence. The new CLI parses the 403 body directly. */ function countryBlockedResponse(req: NextRequest): NextResponse | null { - const countryCode = getCountryCode(req) - if (!countryCode) return null - if (FREE_MODE_ALLOWED_COUNTRIES.has(countryCode)) return null + const countryAccess = getFreeModeCountryAccess(req) + if (countryAccess.allowed) return null return NextResponse.json( - { status: 'country_blocked', countryCode }, + { + status: 'country_blocked', + countryCode: countryAccess.countryCode ?? 'UNKNOWN', + }, { status: 403 }, ) } @@ -52,7 +50,10 @@ type AuthResult = | { error: NextResponse } | { userId: string; userEmail: string | null; userBanned: boolean } -async function resolveUser(req: NextRequest, deps: FreebuffSessionDeps): Promise { +async function resolveUser( + req: NextRequest, + deps: FreebuffSessionDeps, +): Promise { const apiKey = extractApiKeyFromHeader(req) if (!apiKey) { return { @@ -173,7 +174,8 @@ export async function getFreebuffSession( if (blocked) return blocked try { - const claimedInstanceId = req.headers.get(FREEBUFF_INSTANCE_HEADER) ?? undefined + const claimedInstanceId = + req.headers.get(FREEBUFF_INSTANCE_HEADER) ?? undefined const state = await getSessionState({ userId: auth.userId, userEmail: auth.userEmail, diff --git a/web/src/server/__tests__/free-mode-country.test.ts b/web/src/server/__tests__/free-mode-country.test.ts new file mode 100644 index 0000000000..db632b5ad0 --- /dev/null +++ b/web/src/server/__tests__/free-mode-country.test.ts @@ -0,0 +1,45 @@ +import { describe, expect, test } from 'bun:test' +import { NextRequest } from 'next/server' + +import { getFreeModeCountryAccess } from '../free-mode-country' + +function makeReq(headers: Record = {}): NextRequest { + return new NextRequest('http://localhost:3000/api/v1/chat/completions', { + headers, + }) +} + +describe('free mode country access', () => { + test('allows allowlisted Cloudflare countries', () => { + const access = getFreeModeCountryAccess(makeReq({ 'cf-ipcountry': 'us' })) + expect(access.allowed).toBe(true) + expect(access.countryCode).toBe('US') + expect(access.blockReason).toBe(null) + }) + + test('blocks countries outside the allowlist', () => { + const access = getFreeModeCountryAccess(makeReq({ 'cf-ipcountry': 'FR' })) + expect(access.allowed).toBe(false) + expect(access.countryCode).toBe('FR') + expect(access.blockReason).toBe('country_not_allowed') + }) + + test('blocks anonymized Cloudflare country codes without falling back to IP geo', () => { + const access = getFreeModeCountryAccess( + makeReq({ + 'cf-ipcountry': 'T1', + 'x-forwarded-for': '8.8.8.8', + }), + ) + expect(access.allowed).toBe(false) + expect(access.countryCode).toBe(null) + expect(access.blockReason).toBe('anonymized_or_unknown_country') + }) + + test('blocks missing client location as unknown', () => { + const access = getFreeModeCountryAccess(makeReq()) + expect(access.allowed).toBe(false) + expect(access.countryCode).toBe(null) + expect(access.blockReason).toBe('missing_client_ip') + }) +}) diff --git a/web/src/server/free-mode-country.ts b/web/src/server/free-mode-country.ts index 7936e3dcff..684511c9bc 100644 --- a/web/src/server/free-mode-country.ts +++ b/web/src/server/free-mode-country.ts @@ -3,11 +3,41 @@ import geoip from 'geoip-lite' import type { NextRequest } from 'next/server' export const FREE_MODE_ALLOWED_COUNTRIES = new Set([ - 'US', 'CA', - 'GB', 'AU', 'NZ', - 'NO', 'SE', 'NL', 'DK', 'DE', 'FI', 'BE', 'LU', 'CH', 'IE', 'IS', + 'US', + 'CA', + 'GB', + 'AU', + 'NZ', + 'NO', + 'SE', + 'NL', + 'DK', + 'DE', + 'FI', + 'BE', + 'LU', + 'CH', + 'IE', + 'IS', ]) +const CLOUDFLARE_ANONYMIZED_OR_UNKNOWN_COUNTRIES = new Set(['T1', 'XX']) + +export type FreeModeCountryBlockReason = + | 'country_not_allowed' + | 'anonymized_or_unknown_country' + | 'missing_client_ip' + | 'unresolved_client_ip' + +export type FreeModeCountryAccess = { + allowed: boolean + countryCode: string | null + blockReason: FreeModeCountryBlockReason | null + cfCountry: string | null + geoipCountry: string | null + hasClientIp: boolean +} + export function extractClientIp(req: NextRequest): string | undefined { const forwardedFor = req.headers.get('x-forwarded-for') if (forwardedFor) { @@ -16,28 +46,65 @@ export function extractClientIp(req: NextRequest): string | undefined { return req.headers.get('x-real-ip') ?? undefined } -export function getCountryCode(req: NextRequest): string | null { - const cfCountry = req.headers.get('cf-ipcountry') - if (cfCountry && cfCountry !== 'XX' && cfCountry !== 'T1') { - return cfCountry.toUpperCase() +export function getFreeModeCountryAccess( + req: NextRequest, +): FreeModeCountryAccess { + const cfCountry = req.headers.get('cf-ipcountry')?.toUpperCase() ?? null + const clientIp = extractClientIp(req) + + if (cfCountry && CLOUDFLARE_ANONYMIZED_OR_UNKNOWN_COUNTRIES.has(cfCountry)) { + return { + allowed: false, + countryCode: null, + blockReason: 'anonymized_or_unknown_country', + cfCountry, + geoipCountry: null, + hasClientIp: Boolean(clientIp), + } + } + + if (cfCountry) { + const allowed = FREE_MODE_ALLOWED_COUNTRIES.has(cfCountry) + return { + allowed, + countryCode: cfCountry, + blockReason: allowed ? null : 'country_not_allowed', + cfCountry, + geoipCountry: null, + hasClientIp: Boolean(clientIp), + } } - const clientIp = extractClientIp(req) if (!clientIp) { - return null + return { + allowed: false, + countryCode: null, + blockReason: 'missing_client_ip', + cfCountry: null, + geoipCountry: null, + hasClientIp: false, + } } - const geo = geoip.lookup(clientIp) - return geo?.country ?? null -} -/** - * Returns true if the request's resolved country is allowed to use free - * mode, false if it's explicitly disallowed. Returns null when country can't - * be determined (VPN / localhost / corporate proxy) — callers should fail - * open in that case to match the chat-completions gate. - */ -export function isCountryAllowedForFreeMode(req: NextRequest): boolean | null { - const countryCode = getCountryCode(req) - if (!countryCode) return null - return FREE_MODE_ALLOWED_COUNTRIES.has(countryCode) + const geoipCountry = geoip.lookup(clientIp)?.country ?? null + if (!geoipCountry) { + return { + allowed: false, + countryCode: null, + blockReason: 'unresolved_client_ip', + cfCountry: null, + geoipCountry: null, + hasClientIp: true, + } + } + + const allowed = FREE_MODE_ALLOWED_COUNTRIES.has(geoipCountry) + return { + allowed, + countryCode: geoipCountry, + blockReason: allowed ? null : 'country_not_allowed', + cfCountry: null, + geoipCountry, + hasClientIp: true, + } }