From a8886896bbac17ffe1e78b396a21063ebfeeaee1 Mon Sep 17 00:00:00 2001 From: gwizz Date: Thu, 15 Jan 2026 10:38:19 +1100 Subject: [PATCH 1/8] feat(auth): oauth multi-account failover --- packages/opencode/src/auth/context.ts | 21 + packages/opencode/src/auth/index.ts | 451 +++++++++++++++++- packages/opencode/src/auth/rotating-fetch.ts | 232 +++++++++ packages/opencode/src/cli/cmd/auth.ts | 33 +- packages/opencode/src/provider/auth.ts | 13 +- packages/opencode/src/provider/provider.ts | 9 +- .../opencode/test/auth/oauth-rotation.test.ts | 127 +++++ packages/plugin/src/index.ts | 2 + 8 files changed, 840 insertions(+), 48 deletions(-) create mode 100644 packages/opencode/src/auth/context.ts create mode 100644 packages/opencode/src/auth/rotating-fetch.ts create mode 100644 packages/opencode/test/auth/oauth-rotation.test.ts diff --git a/packages/opencode/src/auth/context.ts b/packages/opencode/src/auth/context.ts new file mode 100644 index 000000000000..f3f6efbdbefc --- /dev/null +++ b/packages/opencode/src/auth/context.ts @@ -0,0 +1,21 @@ +import { AsyncLocalStorage } from "async_hooks" + +type Store = { + oauthRecordByProvider: Map +} + +const storage = new AsyncLocalStorage() + +export function getOAuthRecordID(providerID: string): string | undefined { + return storage.getStore()?.oauthRecordByProvider.get(providerID) +} + +export function withOAuthRecord(providerID: string, recordID: string, fn: () => T): T { + const current = storage.getStore() + const next: Store = { + oauthRecordByProvider: new Map(current?.oauthRecordByProvider ?? []), + } + next.oauthRecordByProvider.set(providerID, recordID) + + return storage.run(next, fn) +} diff --git a/packages/opencode/src/auth/index.ts b/packages/opencode/src/auth/index.ts index 3fd28305368e..6dc6e7fa456d 100644 --- a/packages/opencode/src/auth/index.ts +++ b/packages/opencode/src/auth/index.ts @@ -2,6 +2,8 @@ import path from "path" import { Global } from "../global" import fs from "fs/promises" import z from "zod" +import { ulid } from "ulid" +import { getOAuthRecordID } from "./context" export const OAUTH_DUMMY_KEY = "opencode-oauth-dummy-key" @@ -37,37 +39,440 @@ export namespace Auth { const filepath = path.join(Global.Path.data, "auth.json") - export async function get(providerID: string) { - const auth = await all() - return auth[providerID] + const Health = z + .object({ + cooldownUntil: z.number().optional(), + lastStatusCode: z.number().optional(), + lastErrorAt: z.number().optional(), + successCount: z.number().default(0), + failureCount: z.number().default(0), + }) + .strict() + .default(() => ({ successCount: 0, failureCount: 0 })) + type Health = z.infer + + const OAuthRecord = z + .object({ + id: z.string(), + namespace: z.string().default("default"), + label: z.string().optional(), + accountId: z.string().optional(), + enterpriseUrl: z.string().optional(), + refresh: z.string(), + access: z.string(), + expires: z.number(), + createdAt: z.number(), + updatedAt: z.number(), + health: Health, + }) + .strict() + type OAuthRecord = z.infer + + export type OAuthRecordMeta = Omit + + const OAuthProvider = z + .object({ + type: z.literal("oauth"), + active: z.record(z.string(), z.string()).default({}), + order: z.record(z.string(), z.array(z.string())).default({}), + records: z.array(OAuthRecord).default([]), + }) + .strict() + type OAuthProvider = z.infer + + const ApiProvider = z + .object({ + type: z.literal("api"), + key: z.string(), + }) + .strict() + + const WellKnownProvider = z + .object({ + type: z.literal("wellknown"), + key: z.string(), + token: z.string(), + }) + .strict() + + const ProviderEntry = z.union([OAuthProvider, ApiProvider, WellKnownProvider]) + type ProviderEntry = z.infer + + const StoreFile = z + .object({ + version: z.literal(2), + providers: z.record(z.string(), ProviderEntry).default({}), + }) + .strict() + type StoreFile = z.infer + + function toMeta(record: OAuthRecord): OAuthRecordMeta { + const { refresh: _refresh, access: _access, expires: _expires, ...meta } = record + return meta } - export async function all(): Promise> { + async function ensureDataDir(): Promise { + await fs.mkdir(path.dirname(filepath), { recursive: true }) + } + + async function writeStoreFile(store: StoreFile): Promise { + await ensureDataDir() + const tempPath = `${filepath}.tmp` + const tempFile = Bun.file(tempPath) + await Bun.write(tempFile, JSON.stringify(store, null, 2)) + await fs.rename(tempPath, filepath) + await fs.chmod(filepath, 0o600).catch(() => {}) + } + + async function loadStoreFile(): Promise { const file = Bun.file(filepath) - const data = await file.json().catch(() => ({}) as Record) - return Object.entries(data).reduce( - (acc, [key, value]) => { - const parsed = Info.safeParse(value) - if (!parsed.success) return acc - acc[key] = parsed.data - return acc - }, - {} as Record, - ) + const raw = await file.json().catch(() => undefined) + + const parsed = StoreFile.safeParse(raw) + if (parsed.success) return parsed.data + + const legacyParsed = z.record(z.string(), Info).safeParse(raw) + if (legacyParsed.success) { + const now = Date.now() + const next: StoreFile = { version: 2, providers: {} } + + for (const [providerID, info] of Object.entries(legacyParsed.data)) { + if (info.type === "api") { + next.providers[providerID] = { type: "api", key: info.key } + continue + } + + if (info.type === "wellknown") { + next.providers[providerID] = { type: "wellknown", key: info.key, token: info.token } + continue + } + + const recordID = ulid() + next.providers[providerID] = { + type: "oauth", + active: { default: recordID }, + order: { default: [recordID] }, + records: [ + { + id: recordID, + namespace: "default", + label: "default", + accountId: info.accountId, + enterpriseUrl: info.enterpriseUrl, + refresh: info.refresh, + access: info.access, + expires: info.expires, + createdAt: now, + updatedAt: now, + health: { successCount: 0, failureCount: 0 }, + }, + ], + } + } + + await writeStoreFile(next) + return next + } + + return { version: 2, providers: {} } + } + + function ensureOAuthProvider(store: StoreFile, providerID: string): OAuthProvider { + const existing = store.providers[providerID] + if (existing && existing.type === "oauth") return existing + + const next: OAuthProvider = { + type: "oauth", + active: {}, + order: {}, + records: [], + } + store.providers[providerID] = next + return next + } + + function findOAuthRecord(provider: OAuthProvider, recordID: string): OAuthRecord | undefined { + return provider.records.find((record) => record.id === recordID) + } + + function normalizeOrder(ids: string[], order: string[]): string[] { + const ordered: string[] = [] + for (const id of order) { + if (ids.includes(id) && !ordered.includes(id)) ordered.push(id) + } + for (const id of ids) { + if (!ordered.includes(id)) ordered.push(id) + } + return ordered + } + + function recordIDsForNamespace(provider: OAuthProvider, namespace: string): string[] { + const ids = provider.records.filter((record) => record.namespace === namespace).map((record) => record.id) + const order = provider.order[namespace] ?? [] + return normalizeOrder(ids, order) + } + + async function findOAuthRecordIDByRefreshToken(input: { + providerID: string + namespace: string + refresh: string + provider: OAuthProvider + }): Promise { + for (const record of input.provider.records) { + if (record.namespace !== input.namespace) continue + if (record.refresh === input.refresh) return record.id + } + return undefined + } + + export async function get(providerID: string): Promise { + const store = await loadStoreFile() + const entry = store.providers[providerID] + if (!entry) return undefined + + if (entry.type === "api") { + return { type: "api", key: entry.key } + } + + if (entry.type === "wellknown") { + return { type: "wellknown", key: entry.key, token: entry.token } + } + + const namespace = "default" + const contextID = getOAuthRecordID(providerID) + const active = contextID ?? entry.active[namespace] + const ordered = recordIDsForNamespace(entry, namespace) + const recordID = active && ordered.includes(active) ? active : ordered[0] + if (!recordID) return undefined + + const record = findOAuthRecord(entry, recordID) + if (!record) return undefined + return { + type: "oauth", + refresh: record.refresh, + access: record.access, + expires: record.expires, + accountId: record.accountId, + enterpriseUrl: record.enterpriseUrl, + } + } + + export async function all(): Promise> { + const store = await loadStoreFile() + const out: Record = {} + + for (const providerID of Object.keys(store.providers)) { + const info = await get(providerID) + if (!info) continue + out[providerID] = info + } + + return out } export async function set(key: string, info: Info) { - const file = Bun.file(filepath) - const data = await all() - await Bun.write(file, JSON.stringify({ ...data, [key]: info }, null, 2)) - await fs.chmod(file.name!, 0o600) + const store = await loadStoreFile() + + if (info.type === "api") { + store.providers[key] = { type: "api", key: info.key } + await writeStoreFile(store) + return + } + + if (info.type === "wellknown") { + store.providers[key] = { type: "wellknown", key: info.key, token: info.token } + await writeStoreFile(store) + return + } + + const namespace = "default" + const provider = ensureOAuthProvider(store, key) + const recordID = + getOAuthRecordID(key) ?? + (await findOAuthRecordIDByRefreshToken({ providerID: key, namespace, refresh: info.refresh, provider })) ?? + provider.active[namespace] ?? + recordIDsForNamespace(provider, namespace)[0] ?? + ulid() + + const now = Date.now() + const existing = findOAuthRecord(provider, recordID) + if (!existing) { + provider.records.push({ + id: recordID, + namespace, + label: "default", + accountId: info.accountId, + enterpriseUrl: info.enterpriseUrl, + refresh: info.refresh, + access: info.access, + expires: info.expires, + createdAt: now, + updatedAt: now, + health: { successCount: 0, failureCount: 0 }, + }) + provider.order[namespace] = [...(provider.order[namespace] ?? []), recordID] + } else { + existing.refresh = info.refresh + existing.access = info.access + existing.expires = info.expires + existing.updatedAt = now + if (info.accountId !== undefined) existing.accountId = info.accountId + if (info.enterpriseUrl !== undefined) existing.enterpriseUrl = info.enterpriseUrl + const order = provider.order[namespace] ?? [] + if (!order.includes(recordID)) { + provider.order[namespace] = [...order, recordID] + } + } + provider.active[namespace] = recordID + + await writeStoreFile(store) } export async function remove(key: string) { - const file = Bun.file(filepath) - const data = await all() - delete data[key] - await Bun.write(file, JSON.stringify(data, null, 2)) - await fs.chmod(file.name!, 0o600) + const store = await loadStoreFile() + const existing = store.providers[key] + if (!existing) return + + delete store.providers[key] + await writeStoreFile(store) + } + + export async function addOAuth( + providerID: string, + input: Omit, "type"> & { namespace?: string; label?: string }, + ) { + const namespace = (input.namespace ?? "default").trim() || "default" + const store = await loadStoreFile() + + const provider = ensureOAuthProvider(store, providerID) + const now = Date.now() + const existingRecordID = await findOAuthRecordIDByRefreshToken({ + providerID, + namespace, + refresh: input.refresh, + provider, + }) + + if (existingRecordID) { + const existing = findOAuthRecord(provider, existingRecordID) + if (existing) { + existing.refresh = input.refresh + existing.access = input.access + existing.expires = input.expires + existing.updatedAt = now + if (input.accountId !== undefined) existing.accountId = input.accountId + if (input.enterpriseUrl !== undefined) existing.enterpriseUrl = input.enterpriseUrl + if (input.label) existing.label = input.label + } + const order = provider.order[namespace] ?? [] + if (!order.includes(existingRecordID)) { + provider.order[namespace] = [...order, existingRecordID] + } + provider.active[namespace] = existingRecordID + + await writeStoreFile(store) + return { providerID, namespace, recordID: existingRecordID } + } + + const recordID = ulid() + + provider.records.push({ + id: recordID, + namespace, + label: input.label ?? "default", + accountId: input.accountId, + enterpriseUrl: input.enterpriseUrl, + refresh: input.refresh, + access: input.access, + expires: input.expires, + createdAt: now, + updatedAt: now, + health: { successCount: 0, failureCount: 0 }, + }) + + provider.order[namespace] = [...(provider.order[namespace] ?? []), recordID] + provider.active[namespace] = recordID + + await writeStoreFile(store) + + return { providerID, namespace, recordID } + } + + export namespace OAuthPool { + export async function snapshot( + providerID: string, + namespace = "default", + ): Promise<{ records: OAuthRecordMeta[]; orderedIDs: string[] }> { + const store = await loadStoreFile() + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return { records: [], orderedIDs: [] } + + const normalized = namespace.trim() || "default" + const records = provider.records.filter((record) => record.namespace === normalized).map(toMeta) + const orderedIDs = recordIDsForNamespace(provider, normalized) + + return { records, orderedIDs } + } + + export async function list(providerID: string, namespace = "default"): Promise { + return snapshot(providerID, namespace).then((result) => result.records) + } + + export async function orderedIDs(providerID: string, namespace = "default"): Promise { + return snapshot(providerID, namespace).then((result) => result.orderedIDs) + } + + export async function moveToBack(providerID: string, namespace: string, recordID: string): Promise { + const store = await loadStoreFile() + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return + const order = recordIDsForNamespace(provider, namespace) + provider.order[namespace] = order.filter((id) => id !== recordID).concat(recordID) + provider.active[namespace] = provider.order[namespace][0] ?? provider.active[namespace] + await writeStoreFile(store) + } + + export async function recordOutcome(input: { + providerID: string + recordID: string + statusCode: number + ok: boolean + cooldownUntil?: number + }): Promise { + const store = await loadStoreFile() + const provider = store.providers[input.providerID] + if (!provider || provider.type !== "oauth") return + + const record = findOAuthRecord(provider, input.recordID) + if (!record) return + + const now = Date.now() + const prevCooldown = + record.health.cooldownUntil && record.health.cooldownUntil > now ? record.health.cooldownUntil : undefined + const cooldownUntil = input.ok ? undefined : input.cooldownUntil ?? prevCooldown + + record.health = { + ...record.health, + cooldownUntil, + lastStatusCode: input.statusCode, + lastErrorAt: input.ok ? undefined : now, + successCount: record.health.successCount + (input.ok ? 1 : 0), + failureCount: record.health.failureCount + (input.ok ? 0 : 1), + } + record.updatedAt = now + await writeStoreFile(store) + } + + export async function markAccessExpired(providerID: string, namespace: string, recordID: string): Promise { + const store = await loadStoreFile() + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return + const record = findOAuthRecord(provider, recordID) + if (!record || record.namespace !== namespace) return + record.access = "" + record.expires = 0 + record.updatedAt = Date.now() + await writeStoreFile(store) + } } } diff --git a/packages/opencode/src/auth/rotating-fetch.ts b/packages/opencode/src/auth/rotating-fetch.ts new file mode 100644 index 000000000000..dcc451b179d2 --- /dev/null +++ b/packages/opencode/src/auth/rotating-fetch.ts @@ -0,0 +1,232 @@ +import { Auth } from "./index" +import { withOAuthRecord } from "./context" + +const DEFAULT_RATE_LIMIT_COOLDOWN_MS = 30_000 + +function isReadableStream(value: unknown): value is ReadableStream { + return typeof ReadableStream !== "undefined" && value instanceof ReadableStream +} + +function isAsyncIterable(value: unknown): boolean { + return typeof value === "object" && value !== null && Symbol.asyncIterator in value +} + +function isReplayableBody(body: unknown): boolean { + if (!body) return true + if (isReadableStream(body)) return false + if (isAsyncIterable(body)) return false + return true +} + +function isRequest(value: unknown): value is Request { + return typeof Request !== "undefined" && value instanceof Request +} + +async function drainResponse(response: Response): Promise { + try { + await response.body?.cancel() + } catch {} +} + +function parseRetryAfterMs(response: Response): number | undefined { + const value = response.headers.get("retry-after") ?? response.headers.get("Retry-After") + if (!value) return undefined + + const seconds = Number(value) + if (Number.isFinite(seconds)) return Math.max(0, seconds) * 1000 + + const dateMs = Date.parse(value) + if (!Number.isNaN(dateMs)) return Math.max(0, dateMs - Date.now()) + + return undefined +} + +function isAuthExpiredStatus(status: number): boolean { + return status === 401 || status === 403 +} + +export function createOAuthRotatingFetch Promise>( + fetchFn: TFetch, + opts: { + providerID: string + namespace?: string + maxAttempts?: number + }, +): TFetch { + const namespace = (opts.namespace ?? "default").trim() || "default" + + return (async (input: any, init?: any) => { + const { records, orderedIDs } = await Auth.OAuthPool.snapshot(opts.providerID, namespace) + if (records.length === 0) return fetchFn(input, init) + + if (orderedIDs.length <= 1) return fetchFn(input, init) + + const recordByID = new Map(records.map((record) => [record.id, record])) + const candidates = orderedIDs.filter((id) => recordByID.has(id)) + if (candidates.length === 0) return fetchFn(input, init) + const inputIsRequest = isRequest(input) + let allowRetry = + isReplayableBody(init?.body) && (!inputIsRequest || (!input.bodyUsed && !isReadableStream(input.body))) + + let maxAttempts = Math.max(1, opts.maxAttempts ?? candidates.length) + if (!allowRetry) { + maxAttempts = 1 + } else if (maxAttempts > candidates.length) { + maxAttempts = candidates.length + } + + const attempted = new Set() + const refreshed = new Set() + let lastError: unknown + + for (let attempt = 0; attempt < maxAttempts; attempt++) { + const now = Date.now() + + const nextID = + candidates.find((id) => { + if (attempted.has(id)) return false + const cooldownUntil = recordByID.get(id)?.health.cooldownUntil + return !cooldownUntil || cooldownUntil <= now + }) ?? candidates.find((id) => !attempted.has(id)) + + if (!nextID) break + attempted.add(nextID) + + let attemptInput = input + if (inputIsRequest && allowRetry) { + try { + attemptInput = input.clone() + } catch (e) { + lastError = e + allowRetry = false + maxAttempts = attempt + 1 + } + } + + const hasMoreAttempts = attempt + 1 < maxAttempts + + const run = () => withOAuthRecord(opts.providerID, nextID, () => fetchFn(attemptInput, init)) + + let response: Response + try { + response = await run() + } catch (e) { + lastError = e + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: 0, + ok: false, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + if (!hasMoreAttempts) throw e + continue + } + + if (response.ok) { + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: response.status, + ok: true, + }) + return response + } + + if (response.status === 429) { + const cooldownMs = parseRetryAfterMs(response) ?? DEFAULT_RATE_LIMIT_COOLDOWN_MS + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: response.status, + ok: false, + cooldownUntil: Date.now() + cooldownMs, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + if (!hasMoreAttempts) return response + await drainResponse(response) + continue + } + + if (isAuthExpiredStatus(response.status) && !refreshed.has(nextID)) { + refreshed.add(nextID) + + await Auth.OAuthPool.markAccessExpired(opts.providerID, namespace, nextID) + if (!allowRetry) { + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: response.status, + ok: false, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + return response + } + + await drainResponse(response) + + try { + const retry = await run() + if (retry.ok) { + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: retry.status, + ok: true, + }) + return retry + } + + if (retry.status === 429) { + const cooldownMs = parseRetryAfterMs(retry) ?? DEFAULT_RATE_LIMIT_COOLDOWN_MS + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: retry.status, + ok: false, + cooldownUntil: Date.now() + cooldownMs, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + if (!hasMoreAttempts) return retry + await drainResponse(retry) + continue + } + + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: retry.status, + ok: false, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + if (!hasMoreAttempts) return retry + await drainResponse(retry) + continue + } catch (e) { + lastError = e + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: 0, + ok: false, + }) + if (!hasMoreAttempts) throw e + } + + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + continue + } + + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: response.status, + ok: false, + }) + return response + } + + if (lastError) throw lastError + return fetchFn(input, init) + }) as TFetch +} diff --git a/packages/opencode/src/cli/cmd/auth.ts b/packages/opencode/src/cli/cmd/auth.ts index bbaecfd8c711..5f379fc4fd1c 100644 --- a/packages/opencode/src/cli/cmd/auth.ts +++ b/packages/opencode/src/cli/cmd/auth.ts @@ -82,13 +82,12 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): if (result.type === "success") { const saveProvider = result.provider ?? provider if ("refresh" in result) { - const { type: _, provider: __, refresh, access, expires, ...extraFields } = result - await Auth.set(saveProvider, { - type: "oauth", - refresh, - access, - expires, - ...extraFields, + await Auth.addOAuth(saveProvider, { + refresh: result.refresh, + access: result.access, + expires: result.expires, + accountId: result.accountId, + enterpriseUrl: result.enterpriseUrl, }) } if ("key" in result) { @@ -114,13 +113,12 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): if (result.type === "success") { const saveProvider = result.provider ?? provider if ("refresh" in result) { - const { type: _, provider: __, refresh, access, expires, ...extraFields } = result - await Auth.set(saveProvider, { - type: "oauth", - refresh, - access, - expires, - ...extraFields, + await Auth.addOAuth(saveProvider, { + refresh: result.refresh, + access: result.access, + expires: result.expires, + accountId: result.accountId, + enterpriseUrl: result.enterpriseUrl, }) } if ("key" in result) { @@ -182,7 +180,12 @@ export const AuthListCommand = cmd({ for (const [providerID, result] of results) { const name = database[providerID]?.name || providerID - prompts.log.info(`${name} ${UI.Style.TEXT_DIM}${result.type}`) + if (result.type === "oauth") { + const count = await Auth.OAuthPool.list(providerID).then((accounts) => accounts.length) + prompts.log.info(`${name} ${UI.Style.TEXT_DIM}oauth${count > 1 ? ` (${count} accounts)` : ""}`) + } else { + prompts.log.info(`${name} ${UI.Style.TEXT_DIM}${result.type}`) + } } prompts.outro(`${results.length} credentials`) diff --git a/packages/opencode/src/provider/auth.ts b/packages/opencode/src/provider/auth.ts index e6681ff08914..283957b09b2b 100644 --- a/packages/opencode/src/provider/auth.ts +++ b/packages/opencode/src/provider/auth.ts @@ -99,16 +99,13 @@ export namespace ProviderAuth { }) } if ("refresh" in result) { - const info: Auth.Info = { - type: "oauth", - access: result.access, + await Auth.addOAuth(input.providerID, { refresh: result.refresh, + access: result.access, expires: result.expires, - } - if (result.accountId) { - info.accountId = result.accountId - } - await Auth.set(input.providerID, info) + accountId: result.accountId, + enterpriseUrl: result.enterpriseUrl, + }) } return } diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 69946afd83a7..994725c03d85 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -9,6 +9,7 @@ import { Plugin } from "../plugin" import { ModelsDev } from "./models" import { NamedError } from "@opencode-ai/util/error" import { Auth } from "../auth" +import { createOAuthRotatingFetch } from "../auth/rotating-fetch" import { Env } from "../env" import { Instance } from "../project/instance" import { Flag } from "../flag/flag" @@ -976,13 +977,13 @@ export namespace Provider { ...model.headers, } - const key = Bun.hash.xxHash32(JSON.stringify({ npm: model.api.npm, options })) + const key = Bun.hash.xxHash32(JSON.stringify({ providerID: model.providerID, npm: model.api.npm, options })) const existing = s.sdk.get(key) if (existing) return existing const customFetch = options["fetch"] - options["fetch"] = async (input: any, init?: BunFetchRequestInit) => { + const fetchWithTimeout = async (input: any, init?: BunFetchRequestInit) => { // Preserve custom fetch if it exists, wrap it with timeout logic const fetchFn = customFetch ?? fetch const opts = init ?? {} @@ -1004,6 +1005,10 @@ export namespace Provider { }) } + options["fetch"] = createOAuthRotatingFetch(fetchWithTimeout, { + providerID: model.providerID, + }) + // Special case: google-vertex-anthropic uses a subpath import const bundledKey = model.providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : model.api.npm diff --git a/packages/opencode/test/auth/oauth-rotation.test.ts b/packages/opencode/test/auth/oauth-rotation.test.ts new file mode 100644 index 000000000000..357e0494f8da --- /dev/null +++ b/packages/opencode/test/auth/oauth-rotation.test.ts @@ -0,0 +1,127 @@ +import { describe, expect, test } from "bun:test" +import { Auth } from "../../src/auth" +import { createOAuthRotatingFetch } from "../../src/auth/rotating-fetch" +import { withOAuthRecord } from "../../src/auth/context" + +describe("OAuth subscription failover", () => { + const providerID = "oauth-rotation-test" + + test("rotates on 429 (Retry-After) and succeeds with next account", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + if (auth.refresh === "r1") { + return new Response("rate limited", { + status: 429, + headers: { + "Retry-After": "1", + }, + }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a2.recordID) + expect(order[1]).toBe(a1.recordID) + }) + + test("updates the correct OAuth record by refresh token without record context", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + await Auth.set(providerID, { + type: "oauth", + refresh: "r1", + access: "updated-a1", + expires: Date.now() + 60_000, + }) + + const record1 = await withOAuthRecord(providerID, a1.recordID, async () => Auth.get(providerID)) + const record2 = await withOAuthRecord(providerID, a2.recordID, async () => Auth.get(providerID)) + + expect(record1?.type).toBe("oauth") + expect(record1 && record1.type === "oauth" ? record1.access : "").toBe("updated-a1") + + expect(record2?.type).toBe("oauth") + expect(record2 && record2.type === "oauth" ? record2.access : "").toBe("a2") + }) + + test("retries once on 401/403 by forcing refresh, then succeeds", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "bad", + expires: Date.now() + 60_000, + }) + await Auth.addOAuth(providerID, { + refresh: "r2", + access: "ok", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + // Simulate plugin refresh behavior: when access is cleared/expired, + // it refreshes and persists via Auth.set(). + if (!auth.access) { + await Auth.set(providerID, { + type: "oauth", + refresh: auth.refresh, + access: `refreshed-${auth.refresh}`, + expires: Date.now() + 60_000, + }) + return new Response("ok", { status: 200 }) + } + + if (auth.access === "bad") { + return new Response("unauthorized", { status: 401 }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + expect(response.status).toBe(200) + + const record1 = await withOAuthRecord(providerID, a1.recordID, async () => Auth.get(providerID)) + expect(record1?.type).toBe("oauth") + expect(record1 && record1.type === "oauth" ? record1.access : "").toBe("refreshed-r1") + }) +}) diff --git a/packages/plugin/src/index.ts b/packages/plugin/src/index.ts index e57eff579e63..712193bd8e40 100644 --- a/packages/plugin/src/index.ts +++ b/packages/plugin/src/index.ts @@ -115,6 +115,7 @@ export type AuthOuathResult = { url: string; instructions: string } & ( access: string expires: number accountId?: string + enterpriseUrl?: string } | { key: string } )) @@ -135,6 +136,7 @@ export type AuthOuathResult = { url: string; instructions: string } & ( access: string expires: number accountId?: string + enterpriseUrl?: string } | { key: string } )) From 87be01b4c111a3db0d8567e5eb967dc8c964fe01 Mon Sep 17 00:00:00 2001 From: gwizz Date: Thu, 15 Jan 2026 12:56:50 +1100 Subject: [PATCH 2/8] Add OAuth credential failover with toasts --- .../opencode/src/auth/credential-manager.ts | 56 ++++++++++++ packages/opencode/src/auth/rotating-fetch.ts | 41 +++++++-- .../opencode/test/auth/oauth-rotation.test.ts | 86 +++++++++++++++++++ 3 files changed, 177 insertions(+), 6 deletions(-) create mode 100644 packages/opencode/src/auth/credential-manager.ts diff --git a/packages/opencode/src/auth/credential-manager.ts b/packages/opencode/src/auth/credential-manager.ts new file mode 100644 index 000000000000..e244bc576bd5 --- /dev/null +++ b/packages/opencode/src/auth/credential-manager.ts @@ -0,0 +1,56 @@ +import z from "zod" +import { Bus } from "../bus" +import { BusEvent } from "../bus/bus-event" +import { Log } from "../util/log" +import { TuiEvent } from "../cli/cmd/tui/event" + +const log = Log.create({ service: "credential-manager" }) + +export namespace CredentialManager { + export const Event = { + Failover: BusEvent.define( + "credential.failover", + z.object({ + providerID: z.string(), + fromRecordID: z.string(), + toRecordID: z.string().optional(), + statusCode: z.number(), + message: z.string(), + }), + ), + } + + export async function notifyFailover(input: { + providerID: string + fromRecordID: string + toRecordID?: string + statusCode: number + }): Promise { + const isRateLimit = input.statusCode === 429 + const message = isRateLimit + ? `Rate limited on "${input.providerID}". Switching OAuth credential...` + : `Auth error on "${input.providerID}". Switching OAuth credential...` + + log.info("oauth credential failover", { + providerID: input.providerID, + fromRecordID: input.fromRecordID, + toRecordID: input.toRecordID, + statusCode: input.statusCode, + }) + + await Bus.publish(Event.Failover, { + providerID: input.providerID, + fromRecordID: input.fromRecordID, + toRecordID: input.toRecordID, + statusCode: input.statusCode, + message, + }).catch((error) => log.debug("failed to publish credential failover event", { error })) + + await Bus.publish(TuiEvent.ToastShow, { + title: "OAuth Credential Failover", + message, + variant: "warning", + duration: 5000, + }).catch((error) => log.debug("failed to show failover toast", { error })) + } +} diff --git a/packages/opencode/src/auth/rotating-fetch.ts b/packages/opencode/src/auth/rotating-fetch.ts index dcc451b179d2..ce06afaf1c3d 100644 --- a/packages/opencode/src/auth/rotating-fetch.ts +++ b/packages/opencode/src/auth/rotating-fetch.ts @@ -1,7 +1,9 @@ import { Auth } from "./index" import { withOAuthRecord } from "./context" +import { CredentialManager } from "./credential-manager" const DEFAULT_RATE_LIMIT_COOLDOWN_MS = 30_000 +const DEFAULT_AUTH_FAILURE_COOLDOWN_MS = 5 * 60_000 function isReadableStream(value: unknown): value is ReadableStream { return typeof ReadableStream !== "undefined" && value instanceof ReadableStream @@ -59,8 +61,6 @@ export function createOAuthRotatingFetch [record.id, record])) const candidates = orderedIDs.filter((id) => recordByID.has(id)) if (candidates.length === 0) return fetchFn(input, init) @@ -79,6 +79,13 @@ export function createOAuthRotatingFetch() let lastError: unknown + const pickNextCandidate = (now: number) => + candidates.find((id) => { + if (attempted.has(id)) return false + const cooldownUntil = recordByID.get(id)?.health.cooldownUntil + return !cooldownUntil || cooldownUntil <= now + }) ?? candidates.find((id) => !attempted.has(id)) + for (let attempt = 0; attempt < maxAttempts; attempt++) { const now = Date.now() @@ -143,6 +150,15 @@ export function createOAuthRotatingFetch { expect(record1?.type).toBe("oauth") expect(record1 && record1.type === "oauth" ? record1.access : "").toBe("refreshed-r1") }) + + test("fails over on 401/403 when refresh does not fix the credential", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "bad", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "ok", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + if (auth.refresh === "r1") { + return new Response("unauthorized", { status: 401 }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a2.recordID) + expect(order[1]).toBe(a1.recordID) + }) + + test("sticks to the active credential until rate limited", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const counts = new Map() + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + const refresh = auth.refresh + counts.set(refresh, (counts.get(refresh) ?? 0) + 1) + + if (refresh === "r1" && (counts.get(refresh) ?? 0) >= 3) { + return new Response("rate limited", { status: 429 }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + + const first = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + expect(first.status).toBe(200) + + const second = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + expect(second.status).toBe(200) + + const beforeRateLimit = await Auth.OAuthPool.orderedIDs(providerID) + expect(beforeRateLimit[0]).toBe(a1.recordID) + expect(beforeRateLimit[1]).toBe(a2.recordID) + + const third = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + expect(third.status).toBe(200) + + const afterRateLimit = await Auth.OAuthPool.orderedIDs(providerID) + expect(afterRateLimit[0]).toBe(a2.recordID) + expect(afterRateLimit[1]).toBe(a1.recordID) + }) }) From da789e2c3ccf8e38028f3aa8bdc9f2afe3bbdafb Mon Sep 17 00:00:00 2001 From: gwizz Date: Thu, 15 Jan 2026 13:28:06 +1100 Subject: [PATCH 3/8] fix(auth): improve oauth failover notifications --- packages/opencode/src/auth/rotating-fetch.ts | 64 ++++++++++++++------ 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/packages/opencode/src/auth/rotating-fetch.ts b/packages/opencode/src/auth/rotating-fetch.ts index ce06afaf1c3d..5f735cd50847 100644 --- a/packages/opencode/src/auth/rotating-fetch.ts +++ b/packages/opencode/src/auth/rotating-fetch.ts @@ -152,12 +152,14 @@ export function createOAuthRotatingFetch Date: Thu, 15 Jan 2026 13:38:46 +1100 Subject: [PATCH 4/8] Increase failover toast duration --- packages/opencode/src/auth/credential-manager.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/opencode/src/auth/credential-manager.ts b/packages/opencode/src/auth/credential-manager.ts index e244bc576bd5..50dd60ef1005 100644 --- a/packages/opencode/src/auth/credential-manager.ts +++ b/packages/opencode/src/auth/credential-manager.ts @@ -50,7 +50,7 @@ export namespace CredentialManager { title: "OAuth Credential Failover", message, variant: "warning", - duration: 5000, + duration: 8000, }).catch((error) => log.debug("failed to show failover toast", { error })) } } From 39569b96de4ebaab661623dac9f6a1d5ba3f871d Mon Sep 17 00:00:00 2001 From: gwizz Date: Thu, 15 Jan 2026 14:10:57 +1100 Subject: [PATCH 5/8] Improve OAuth rotation resiliency --- packages/opencode/src/auth/context.ts | 2 +- .../opencode/src/auth/credential-manager.ts | 9 +- packages/opencode/src/auth/index.ts | 344 ++++++++++-------- packages/opencode/src/auth/rotating-fetch.ts | 87 ++--- packages/opencode/src/config/config.ts | 14 + packages/opencode/src/provider/provider.ts | 6 + .../opencode/test/auth/oauth-rotation.test.ts | 111 ++++++ 7 files changed, 369 insertions(+), 204 deletions(-) diff --git a/packages/opencode/src/auth/context.ts b/packages/opencode/src/auth/context.ts index f3f6efbdbefc..ee21f02cf069 100644 --- a/packages/opencode/src/auth/context.ts +++ b/packages/opencode/src/auth/context.ts @@ -1,4 +1,4 @@ -import { AsyncLocalStorage } from "async_hooks" +import { AsyncLocalStorage } from "node:async_hooks" type Store = { oauthRecordByProvider: Map diff --git a/packages/opencode/src/auth/credential-manager.ts b/packages/opencode/src/auth/credential-manager.ts index 50dd60ef1005..c7b499f953dd 100644 --- a/packages/opencode/src/auth/credential-manager.ts +++ b/packages/opencode/src/auth/credential-manager.ts @@ -5,6 +5,7 @@ import { Log } from "../util/log" import { TuiEvent } from "../cli/cmd/tui/event" const log = Log.create({ service: "credential-manager" }) +const DEFAULT_FAILOVER_TOAST_MS = 8000 export namespace CredentialManager { export const Event = { @@ -25,11 +26,15 @@ export namespace CredentialManager { fromRecordID: string toRecordID?: string statusCode: number + toastDurationMs?: number }): Promise { const isRateLimit = input.statusCode === 429 const message = isRateLimit ? `Rate limited on "${input.providerID}". Switching OAuth credential...` - : `Auth error on "${input.providerID}". Switching OAuth credential...` + : input.statusCode === 0 + ? `Request failed on "${input.providerID}". Switching OAuth credential...` + : `Auth error on "${input.providerID}". Switching OAuth credential...` + const duration = Math.max(0, input.toastDurationMs ?? DEFAULT_FAILOVER_TOAST_MS) log.info("oauth credential failover", { providerID: input.providerID, @@ -50,7 +55,7 @@ export namespace CredentialManager { title: "OAuth Credential Failover", message, variant: "warning", - duration: 8000, + duration, }).catch((error) => log.debug("failed to show failover toast", { error })) } } diff --git a/packages/opencode/src/auth/index.ts b/packages/opencode/src/auth/index.ts index 6dc6e7fa456d..30ca1484f437 100644 --- a/packages/opencode/src/auth/index.ts +++ b/packages/opencode/src/auth/index.ts @@ -38,6 +38,10 @@ export namespace Auth { export type Info = z.infer const filepath = path.join(Global.Path.data, "auth.json") + const lockpath = `${filepath}.lock` + const STORE_LOCK_TIMEOUT_MS = 5_000 + const STORE_LOCK_STALE_MS = 30_000 + const STORE_LOCK_RETRY_MS = 25 const Health = z .object({ @@ -115,6 +119,36 @@ export namespace Auth { await fs.mkdir(path.dirname(filepath), { recursive: true }) } + async function withStoreLock(fn: () => Promise): Promise { + await ensureDataDir() + const start = Date.now() + while (true) { + try { + const handle = await fs.open(lockpath, "wx") + await handle.close() + break + } catch (error) { + const code = (error as { code?: string }).code + if (code !== "EEXIST") throw error + const stat = await fs.stat(lockpath).catch(() => undefined) + if (stat && Date.now() - stat.mtimeMs > STORE_LOCK_STALE_MS) { + await fs.rm(lockpath).catch(() => {}) + continue + } + if (Date.now() - start > STORE_LOCK_TIMEOUT_MS) { + throw new Error("Timed out waiting for auth store lock") + } + await Bun.sleep(STORE_LOCK_RETRY_MS + Math.random() * STORE_LOCK_RETRY_MS) + } + } + + try { + return await fn() + } finally { + await fs.rm(lockpath).catch(() => {}) + } + } + async function writeStoreFile(store: StoreFile): Promise { await ensureDataDir() const tempPath = `${filepath}.tmp` @@ -124,12 +158,13 @@ export namespace Auth { await fs.chmod(filepath, 0o600).catch(() => {}) } - async function loadStoreFile(): Promise { + async function readStoreFile(): Promise<{ store: StoreFile; needsWrite: boolean }> { const file = Bun.file(filepath) + const exists = await file.exists() const raw = await file.json().catch(() => undefined) const parsed = StoreFile.safeParse(raw) - if (parsed.success) return parsed.data + if (parsed.success) return { store: parsed.data, needsWrite: false } const legacyParsed = z.record(z.string(), Info).safeParse(raw) if (legacyParsed.success) { @@ -170,11 +205,31 @@ export namespace Auth { } } - await writeStoreFile(next) - return next + return { store: next, needsWrite: true } } - return { version: 2, providers: {} } + return { store: { version: 2, providers: {} }, needsWrite: exists } + } + + async function loadStoreFile(): Promise { + const result = await readStoreFile() + return result.store + } + + type StoreUpdateResult = { + value: T + changed: boolean + } + + async function updateStore(fn: (store: StoreFile) => Promise> | StoreUpdateResult) { + return withStoreLock(async () => { + const { store, needsWrite } = await readStoreFile() + const result = await fn(store) + if (result.changed || needsWrite) { + await writeStoreFile(store) + } + return result.value + }) } function ensureOAuthProvider(store: StoreFile, providerID: string): OAuthProvider { @@ -271,70 +326,69 @@ export namespace Auth { } export async function set(key: string, info: Info) { - const store = await loadStoreFile() + return updateStore(async (store) => { + if (info.type === "api") { + store.providers[key] = { type: "api", key: info.key } + return { value: undefined, changed: true } + } - if (info.type === "api") { - store.providers[key] = { type: "api", key: info.key } - await writeStoreFile(store) - return - } + if (info.type === "wellknown") { + store.providers[key] = { type: "wellknown", key: info.key, token: info.token } + return { value: undefined, changed: true } + } - if (info.type === "wellknown") { - store.providers[key] = { type: "wellknown", key: info.key, token: info.token } - await writeStoreFile(store) - return - } + const namespace = "default" + const provider = ensureOAuthProvider(store, key) + const recordID = + getOAuthRecordID(key) ?? + (await findOAuthRecordIDByRefreshToken({ providerID: key, namespace, refresh: info.refresh, provider })) ?? + provider.active[namespace] ?? + recordIDsForNamespace(provider, namespace)[0] ?? + ulid() - const namespace = "default" - const provider = ensureOAuthProvider(store, key) - const recordID = - getOAuthRecordID(key) ?? - (await findOAuthRecordIDByRefreshToken({ providerID: key, namespace, refresh: info.refresh, provider })) ?? - provider.active[namespace] ?? - recordIDsForNamespace(provider, namespace)[0] ?? - ulid() - - const now = Date.now() - const existing = findOAuthRecord(provider, recordID) - if (!existing) { - provider.records.push({ - id: recordID, - namespace, - label: "default", - accountId: info.accountId, - enterpriseUrl: info.enterpriseUrl, - refresh: info.refresh, - access: info.access, - expires: info.expires, - createdAt: now, - updatedAt: now, - health: { successCount: 0, failureCount: 0 }, - }) - provider.order[namespace] = [...(provider.order[namespace] ?? []), recordID] - } else { - existing.refresh = info.refresh - existing.access = info.access - existing.expires = info.expires - existing.updatedAt = now - if (info.accountId !== undefined) existing.accountId = info.accountId - if (info.enterpriseUrl !== undefined) existing.enterpriseUrl = info.enterpriseUrl - const order = provider.order[namespace] ?? [] - if (!order.includes(recordID)) { - provider.order[namespace] = [...order, recordID] + const now = Date.now() + const existing = findOAuthRecord(provider, recordID) + if (!existing) { + provider.records.push({ + id: recordID, + namespace, + label: "default", + accountId: info.accountId, + enterpriseUrl: info.enterpriseUrl, + refresh: info.refresh, + access: info.access, + expires: info.expires, + createdAt: now, + updatedAt: now, + health: { successCount: 0, failureCount: 0 }, + }) + provider.order[namespace] = [...(provider.order[namespace] ?? []), recordID] + } else { + existing.refresh = info.refresh + existing.access = info.access + existing.expires = info.expires + existing.updatedAt = now + if (info.accountId !== undefined) existing.accountId = info.accountId + if (info.enterpriseUrl !== undefined) existing.enterpriseUrl = info.enterpriseUrl + const order = provider.order[namespace] ?? [] + if (!order.includes(recordID)) { + provider.order[namespace] = [...order, recordID] + } } - } - provider.active[namespace] = recordID + provider.active[namespace] = recordID - await writeStoreFile(store) + return { value: undefined, changed: true } + }) } export async function remove(key: string) { - const store = await loadStoreFile() - const existing = store.providers[key] - if (!existing) return + return updateStore((store) => { + const existing = store.providers[key] + if (!existing) return { value: undefined, changed: false } - delete store.providers[key] - await writeStoreFile(store) + delete store.providers[key] + return { value: undefined, changed: true } + }) } export async function addOAuth( @@ -342,60 +396,57 @@ export namespace Auth { input: Omit, "type"> & { namespace?: string; label?: string }, ) { const namespace = (input.namespace ?? "default").trim() || "default" - const store = await loadStoreFile() + return updateStore(async (store) => { + const provider = ensureOAuthProvider(store, providerID) + const now = Date.now() + const existingRecordID = await findOAuthRecordIDByRefreshToken({ + providerID, + namespace, + refresh: input.refresh, + provider, + }) - const provider = ensureOAuthProvider(store, providerID) - const now = Date.now() - const existingRecordID = await findOAuthRecordIDByRefreshToken({ - providerID, - namespace, - refresh: input.refresh, - provider, - }) + if (existingRecordID) { + const existing = findOAuthRecord(provider, existingRecordID) + if (existing) { + existing.refresh = input.refresh + existing.access = input.access + existing.expires = input.expires + existing.updatedAt = now + if (input.accountId !== undefined) existing.accountId = input.accountId + if (input.enterpriseUrl !== undefined) existing.enterpriseUrl = input.enterpriseUrl + if (input.label) existing.label = input.label + } + const order = provider.order[namespace] ?? [] + if (!order.includes(existingRecordID)) { + provider.order[namespace] = [...order, existingRecordID] + } + provider.active[namespace] = existingRecordID - if (existingRecordID) { - const existing = findOAuthRecord(provider, existingRecordID) - if (existing) { - existing.refresh = input.refresh - existing.access = input.access - existing.expires = input.expires - existing.updatedAt = now - if (input.accountId !== undefined) existing.accountId = input.accountId - if (input.enterpriseUrl !== undefined) existing.enterpriseUrl = input.enterpriseUrl - if (input.label) existing.label = input.label - } - const order = provider.order[namespace] ?? [] - if (!order.includes(existingRecordID)) { - provider.order[namespace] = [...order, existingRecordID] + return { value: { providerID, namespace, recordID: existingRecordID }, changed: true } } - provider.active[namespace] = existingRecordID - await writeStoreFile(store) - return { providerID, namespace, recordID: existingRecordID } - } - - const recordID = ulid() - - provider.records.push({ - id: recordID, - namespace, - label: input.label ?? "default", - accountId: input.accountId, - enterpriseUrl: input.enterpriseUrl, - refresh: input.refresh, - access: input.access, - expires: input.expires, - createdAt: now, - updatedAt: now, - health: { successCount: 0, failureCount: 0 }, - }) + const recordID = ulid() - provider.order[namespace] = [...(provider.order[namespace] ?? []), recordID] - provider.active[namespace] = recordID + provider.records.push({ + id: recordID, + namespace, + label: input.label ?? "default", + accountId: input.accountId, + enterpriseUrl: input.enterpriseUrl, + refresh: input.refresh, + access: input.access, + expires: input.expires, + createdAt: now, + updatedAt: now, + health: { successCount: 0, failureCount: 0 }, + }) - await writeStoreFile(store) + provider.order[namespace] = [...(provider.order[namespace] ?? []), recordID] + provider.active[namespace] = recordID - return { providerID, namespace, recordID } + return { value: { providerID, namespace, recordID }, changed: true } + }) } export namespace OAuthPool { @@ -423,13 +474,14 @@ export namespace Auth { } export async function moveToBack(providerID: string, namespace: string, recordID: string): Promise { - const store = await loadStoreFile() - const provider = store.providers[providerID] - if (!provider || provider.type !== "oauth") return - const order = recordIDsForNamespace(provider, namespace) - provider.order[namespace] = order.filter((id) => id !== recordID).concat(recordID) - provider.active[namespace] = provider.order[namespace][0] ?? provider.active[namespace] - await writeStoreFile(store) + await updateStore((store) => { + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return { value: undefined, changed: false } + const order = recordIDsForNamespace(provider, namespace) + provider.order[namespace] = order.filter((id) => id !== recordID).concat(recordID) + provider.active[namespace] = provider.order[namespace][0] ?? provider.active[namespace] + return { value: undefined, changed: true } + }) } export async function recordOutcome(input: { @@ -439,40 +491,42 @@ export namespace Auth { ok: boolean cooldownUntil?: number }): Promise { - const store = await loadStoreFile() - const provider = store.providers[input.providerID] - if (!provider || provider.type !== "oauth") return - - const record = findOAuthRecord(provider, input.recordID) - if (!record) return - - const now = Date.now() - const prevCooldown = - record.health.cooldownUntil && record.health.cooldownUntil > now ? record.health.cooldownUntil : undefined - const cooldownUntil = input.ok ? undefined : input.cooldownUntil ?? prevCooldown - - record.health = { - ...record.health, - cooldownUntil, - lastStatusCode: input.statusCode, - lastErrorAt: input.ok ? undefined : now, - successCount: record.health.successCount + (input.ok ? 1 : 0), - failureCount: record.health.failureCount + (input.ok ? 0 : 1), - } - record.updatedAt = now - await writeStoreFile(store) + await updateStore((store) => { + const provider = store.providers[input.providerID] + if (!provider || provider.type !== "oauth") return { value: undefined, changed: false } + + const record = findOAuthRecord(provider, input.recordID) + if (!record) return { value: undefined, changed: false } + + const now = Date.now() + const prevCooldown = + record.health.cooldownUntil && record.health.cooldownUntil > now ? record.health.cooldownUntil : undefined + const cooldownUntil = input.ok ? undefined : input.cooldownUntil ?? prevCooldown + + record.health = { + ...record.health, + cooldownUntil, + lastStatusCode: input.statusCode, + lastErrorAt: input.ok ? undefined : now, + successCount: record.health.successCount + (input.ok ? 1 : 0), + failureCount: record.health.failureCount + (input.ok ? 0 : 1), + } + record.updatedAt = now + return { value: undefined, changed: true } + }) } export async function markAccessExpired(providerID: string, namespace: string, recordID: string): Promise { - const store = await loadStoreFile() - const provider = store.providers[providerID] - if (!provider || provider.type !== "oauth") return - const record = findOAuthRecord(provider, recordID) - if (!record || record.namespace !== namespace) return - record.access = "" - record.expires = 0 - record.updatedAt = Date.now() - await writeStoreFile(store) + await updateStore((store) => { + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return { value: undefined, changed: false } + const record = findOAuthRecord(provider, recordID) + if (!record || record.namespace !== namespace) return { value: undefined, changed: false } + record.access = "" + record.expires = 0 + record.updatedAt = Date.now() + return { value: undefined, changed: true } + }) } } } diff --git a/packages/opencode/src/auth/rotating-fetch.ts b/packages/opencode/src/auth/rotating-fetch.ts index 5f735cd50847..2faa333abf7f 100644 --- a/packages/opencode/src/auth/rotating-fetch.ts +++ b/packages/opencode/src/auth/rotating-fetch.ts @@ -4,12 +4,13 @@ import { CredentialManager } from "./credential-manager" const DEFAULT_RATE_LIMIT_COOLDOWN_MS = 30_000 const DEFAULT_AUTH_FAILURE_COOLDOWN_MS = 5 * 60_000 +const DEFAULT_MAX_ATTEMPTS = 5 function isReadableStream(value: unknown): value is ReadableStream { return typeof ReadableStream !== "undefined" && value instanceof ReadableStream } -function isAsyncIterable(value: unknown): boolean { +function isAsyncIterable(value: unknown): value is AsyncIterable { return typeof value === "object" && value !== null && Symbol.asyncIterator in value } @@ -53,6 +54,9 @@ export function createOAuthRotatingFetch candidates.length) { @@ -89,12 +96,7 @@ export function createOAuthRotatingFetch { - if (attempted.has(id)) return false - const cooldownUntil = recordByID.get(id)?.health.cooldownUntil - return !cooldownUntil || cooldownUntil <= now - }) ?? candidates.find((id) => !attempted.has(id)) + const nextID = pickNextCandidate(now) if (!nextID) break attempted.add(nextID) @@ -113,6 +115,17 @@ export function createOAuthRotatingFetch withOAuthRecord(opts.providerID, nextID, () => fetchFn(attemptInput, init)) + const notifyFailover = async (statusCode: number) => { + const candidate = pickNextCandidate(Date.now()) + if (!candidate) return + await CredentialManager.notifyFailover({ + providerID: opts.providerID, + fromRecordID: nextID, + toRecordID: candidate, + statusCode, + toastDurationMs: opts.toastDurationMs, + }) + } let response: Response try { @@ -126,6 +139,7 @@ export function createOAuthRotatingFetch { expect(afterRateLimit[0]).toBe(a2.recordID) expect(afterRateLimit[1]).toBe(a1.recordID) }) + + test("does not retry non-replayable bodies but rotates for the next request", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const counts = new Map() + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + counts.set(auth.refresh, (counts.get(auth.refresh) ?? 0) + 1) + return new Response("rate limited", { status: 429 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const body = new ReadableStream({ + start(controller) { + controller.enqueue(new TextEncoder().encode("payload")) + controller.close() + }, + }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body }) + + expect(response.status).toBe(429) + expect(counts.get("r1") ?? 0).toBe(1) + expect(counts.get("r2") ?? 0).toBe(0) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a2.recordID) + expect(order[1]).toBe(a1.recordID) + }) + + test("returns the last response when all credentials are exhausted", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + return new Response("rate limited", { status: 429 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(429) + + const records = await Auth.OAuthPool.list(providerID) + const recordByID = new Map(records.map((record) => [record.id, record])) + expect(recordByID.get(a1.recordID)?.health.failureCount ?? 0).toBe(1) + expect(recordByID.get(a2.recordID)?.health.failureCount ?? 0).toBe(1) + }) + + test("fails over when a request throws", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + if (auth.refresh === "r1") { + throw new Error("network down") + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + const records = await Auth.OAuthPool.list(providerID) + const recordByID = new Map(records.map((record) => [record.id, record])) + expect(recordByID.get(a1.recordID)?.health.failureCount ?? 0).toBe(1) + expect(recordByID.get(a2.recordID)?.health.failureCount ?? 0).toBe(0) + }) }) From 682bcde1d9c629ebe28251e795f9cf0e2f4ebb9e Mon Sep 17 00:00:00 2001 From: gwizz Date: Thu, 15 Jan 2026 14:21:55 +1100 Subject: [PATCH 6/8] Add OAuth retry edge case tests --- .../opencode/test/auth/oauth-rotation.test.ts | 92 +++++++++++++++++++ packages/plugin/src/index.ts | 4 +- 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/packages/opencode/test/auth/oauth-rotation.test.ts b/packages/opencode/test/auth/oauth-rotation.test.ts index fd8c870fff60..4c2b631c87fc 100644 --- a/packages/opencode/test/auth/oauth-rotation.test.ts +++ b/packages/opencode/test/auth/oauth-rotation.test.ts @@ -321,4 +321,96 @@ describe("OAuth subscription failover", () => { expect(recordByID.get(a1.recordID)?.health.failureCount ?? 0).toBe(1) expect(recordByID.get(a2.recordID)?.health.failureCount ?? 0).toBe(0) }) + + test("respects Retry-After HTTP date headers", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const originalNow = Date.now + const now = 1_700_000_000_000 + Date.now = () => now + + try { + const retryAt = new Date(now + 5_000).toUTCString() + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + if (auth.refresh === "r1") { + return new Response("rate limited", { + status: 429, + headers: { + "Retry-After": retryAt, + }, + }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + const records = await Auth.OAuthPool.list(providerID) + const recordByID = new Map(records.map((record) => [record.id, record])) + expect(recordByID.get(a1.recordID)?.health.cooldownUntil).toBe(now + 5_000) + } finally { + Date.now = originalNow + } + }) + + test("falls back when Request.clone throws", async () => { + await Auth.remove(providerID) + + await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const counts = new Map() + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + counts.set(auth.refresh, (counts.get(auth.refresh) ?? 0) + 1) + + if (auth.refresh === "r1") { + return new Response("rate limited", { status: 429 }) + } + + return new Response("ok", { status: 200 }) + } + + const request = new Request("https://example.com", { method: "POST" }) + ;(request as { clone: () => Request }).clone = () => { + throw new Error("clone failed") + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover(request) + + expect(response.status).toBe(429) + expect(counts.get("r1") ?? 0).toBe(1) + expect(counts.get("r2") ?? 0).toBe(0) + }) }) diff --git a/packages/plugin/src/index.ts b/packages/plugin/src/index.ts index 712193bd8e40..b95b44fe5733 100644 --- a/packages/plugin/src/index.ts +++ b/packages/plugin/src/index.ts @@ -115,7 +115,7 @@ export type AuthOuathResult = { url: string; instructions: string } & ( access: string expires: number accountId?: string - enterpriseUrl?: string + enterpriseUrl?: string // Used for GitHub Copilot Enterprise auth flows. } | { key: string } )) @@ -136,7 +136,7 @@ export type AuthOuathResult = { url: string; instructions: string } & ( access: string expires: number accountId?: string - enterpriseUrl?: string + enterpriseUrl?: string // Used for GitHub Copilot Enterprise auth flows. } | { key: string } )) From 8b0b2633b0b018bba4f2e2facc66d8dfbb023176 Mon Sep 17 00:00:00 2001 From: gwizz Date: Thu, 15 Jan 2026 14:40:05 +1100 Subject: [PATCH 7/8] Relax auth store updates under lock contention --- packages/opencode/src/auth/index.ts | 62 ++++++++++++++++---- packages/opencode/src/auth/rotating-fetch.ts | 3 +- 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/packages/opencode/src/auth/index.ts b/packages/opencode/src/auth/index.ts index 30ca1484f437..7e0c306c78e4 100644 --- a/packages/opencode/src/auth/index.ts +++ b/packages/opencode/src/auth/index.ts @@ -4,6 +4,7 @@ import fs from "fs/promises" import z from "zod" import { ulid } from "ulid" import { getOAuthRecordID } from "./context" +import { Log } from "../util/log" export const OAUTH_DUMMY_KEY = "opencode-oauth-dummy-key" @@ -42,6 +43,17 @@ export namespace Auth { const STORE_LOCK_TIMEOUT_MS = 5_000 const STORE_LOCK_STALE_MS = 30_000 const STORE_LOCK_RETRY_MS = 25 + const STORE_LOCK_BEST_EFFORT_TIMEOUT_MS = 250 + const STORE_LOCK_BEST_EFFORT_RETRY_MS = 10 + + const log = Log.create({ service: "auth.store" }) + + class StoreLockTimeoutError extends Error { + constructor() { + super("Timed out waiting for auth store lock") + this.name = "StoreLockTimeoutError" + } + } const Health = z .object({ @@ -119,8 +131,14 @@ export namespace Auth { await fs.mkdir(path.dirname(filepath), { recursive: true }) } - async function withStoreLock(fn: () => Promise): Promise { + async function withStoreLock( + fn: () => Promise, + options: { timeoutMs?: number; staleMs?: number; retryMs?: number } = {}, + ): Promise { await ensureDataDir() + const timeoutMs = options.timeoutMs ?? STORE_LOCK_TIMEOUT_MS + const staleMs = options.staleMs ?? STORE_LOCK_STALE_MS + const retryMs = options.retryMs ?? STORE_LOCK_RETRY_MS const start = Date.now() while (true) { try { @@ -131,14 +149,14 @@ export namespace Auth { const code = (error as { code?: string }).code if (code !== "EEXIST") throw error const stat = await fs.stat(lockpath).catch(() => undefined) - if (stat && Date.now() - stat.mtimeMs > STORE_LOCK_STALE_MS) { + if (stat && Date.now() - stat.mtimeMs > staleMs) { await fs.rm(lockpath).catch(() => {}) continue } - if (Date.now() - start > STORE_LOCK_TIMEOUT_MS) { - throw new Error("Timed out waiting for auth store lock") + if (Date.now() - start > timeoutMs) { + throw new StoreLockTimeoutError() } - await Bun.sleep(STORE_LOCK_RETRY_MS + Math.random() * STORE_LOCK_RETRY_MS) + await Bun.sleep(retryMs + Math.random() * retryMs) } } @@ -221,7 +239,10 @@ export namespace Auth { changed: boolean } - async function updateStore(fn: (store: StoreFile) => Promise> | StoreUpdateResult) { + async function updateStoreWithLock( + fn: (store: StoreFile) => Promise> | StoreUpdateResult, + lockOptions?: { timeoutMs?: number; staleMs?: number; retryMs?: number }, + ) { return withStoreLock(async () => { const { store, needsWrite } = await readStoreFile() const result = await fn(store) @@ -229,7 +250,28 @@ export namespace Auth { await writeStoreFile(store) } return result.value - }) + }, lockOptions) + } + + async function updateStore(fn: (store: StoreFile) => Promise> | StoreUpdateResult) { + return updateStoreWithLock(fn) + } + + async function updateStoreBestEffort( + fn: (store: StoreFile) => Promise> | StoreUpdateResult, + ): Promise { + try { + await updateStoreWithLock(fn, { + timeoutMs: STORE_LOCK_BEST_EFFORT_TIMEOUT_MS, + retryMs: STORE_LOCK_BEST_EFFORT_RETRY_MS, + }) + } catch (error) { + if (error instanceof StoreLockTimeoutError) { + log.warn("auth store lock busy, skipping update", { timeoutMs: STORE_LOCK_BEST_EFFORT_TIMEOUT_MS }) + return + } + throw error + } } function ensureOAuthProvider(store: StoreFile, providerID: string): OAuthProvider { @@ -474,7 +516,7 @@ export namespace Auth { } export async function moveToBack(providerID: string, namespace: string, recordID: string): Promise { - await updateStore((store) => { + await updateStoreBestEffort((store) => { const provider = store.providers[providerID] if (!provider || provider.type !== "oauth") return { value: undefined, changed: false } const order = recordIDsForNamespace(provider, namespace) @@ -491,7 +533,7 @@ export namespace Auth { ok: boolean cooldownUntil?: number }): Promise { - await updateStore((store) => { + await updateStoreBestEffort((store) => { const provider = store.providers[input.providerID] if (!provider || provider.type !== "oauth") return { value: undefined, changed: false } @@ -517,7 +559,7 @@ export namespace Auth { } export async function markAccessExpired(providerID: string, namespace: string, recordID: string): Promise { - await updateStore((store) => { + await updateStoreBestEffort((store) => { const provider = store.providers[providerID] if (!provider || provider.type !== "oauth") return { value: undefined, changed: false } const record = findOAuthRecord(provider, recordID) diff --git a/packages/opencode/src/auth/rotating-fetch.ts b/packages/opencode/src/auth/rotating-fetch.ts index 2faa333abf7f..8714a111c4bd 100644 --- a/packages/opencode/src/auth/rotating-fetch.ts +++ b/packages/opencode/src/auth/rotating-fetch.ts @@ -4,7 +4,6 @@ import { CredentialManager } from "./credential-manager" const DEFAULT_RATE_LIMIT_COOLDOWN_MS = 30_000 const DEFAULT_AUTH_FAILURE_COOLDOWN_MS = 5 * 60_000 -const DEFAULT_MAX_ATTEMPTS = 5 function isReadableStream(value: unknown): value is ReadableStream { return typeof ReadableStream !== "undefined" && value instanceof ReadableStream @@ -74,7 +73,7 @@ export function createOAuthRotatingFetch Date: Fri, 16 Jan 2026 00:51:17 +1100 Subject: [PATCH 8/8] Rotate oauth on non-network errors Non-network errors seen in logs: - AI_APICallError (402 deactivated_workspace) - AI_APICallError (500 server_error) - AI_LoadAPIKeyError / OpenAI API key is missing - ProviderInitError - ConfigInvalidError - ProviderAuthOauthCallbackFailed - NotFoundError - EditBuffer is destroyed --- packages/opencode/src/auth/rotating-fetch.ts | 138 +++++++++++++----- packages/opencode/src/config/config.ts | 6 + packages/opencode/src/provider/provider.ts | 1 + .../opencode/test/auth/oauth-rotation.test.ts | 59 +++++++- 4 files changed, 163 insertions(+), 41 deletions(-) diff --git a/packages/opencode/src/auth/rotating-fetch.ts b/packages/opencode/src/auth/rotating-fetch.ts index 8714a111c4bd..7e03b5af0957 100644 --- a/packages/opencode/src/auth/rotating-fetch.ts +++ b/packages/opencode/src/auth/rotating-fetch.ts @@ -4,6 +4,7 @@ import { CredentialManager } from "./credential-manager" const DEFAULT_RATE_LIMIT_COOLDOWN_MS = 30_000 const DEFAULT_AUTH_FAILURE_COOLDOWN_MS = 5 * 60_000 +const DEFAULT_NETWORK_RETRY_ATTEMPTS = 1 function isReadableStream(value: unknown): value is ReadableStream { return typeof ReadableStream !== "undefined" && value instanceof ReadableStream @@ -43,6 +44,56 @@ function parseRetryAfterMs(response: Response): number | undefined { return undefined } +const NETWORK_ERROR_CODES = new Set([ + "ECONNRESET", + "ECONNREFUSED", + "EHOSTUNREACH", + "ENETUNREACH", + "ENOTFOUND", + "EAI_AGAIN", + "ETIMEDOUT", + "ECONNABORTED", + "EPIPE", + "UND_ERR_CONNECT_TIMEOUT", + "UND_ERR_HEADERS_TIMEOUT", + "UND_ERR_BODY_TIMEOUT", + "UND_ERR_SOCKET", +]) +const NETWORK_ERROR_NAMES = new Set(["AbortError", "TimeoutError", "FetchError"]) + +function extractErrorCode(error: unknown): string | undefined { + if (!error || typeof error !== "object") return undefined + const code = (error as { code?: unknown }).code + return typeof code === "string" ? code : undefined +} + +function extractErrorName(error: unknown): string | undefined { + if (!error || typeof error !== "object") return undefined + const name = (error as { name?: unknown }).name + return typeof name === "string" ? name : undefined +} + +function extractErrorMessage(error: unknown): string | undefined { + if (!error || typeof error !== "object") return undefined + const message = (error as { message?: unknown }).message + return typeof message === "string" ? message : undefined +} + +function isNetworkError(error: unknown): boolean { + const directCode = extractErrorCode(error) + const cause = typeof error === "object" && error !== null ? (error as { cause?: unknown }).cause : undefined + const causeCode = extractErrorCode(cause) + const code = directCode ?? causeCode + if (code && NETWORK_ERROR_CODES.has(code)) return true + + const name = extractErrorName(error) + if (name && NETWORK_ERROR_NAMES.has(name)) return true + + const message = extractErrorMessage(error)?.toLowerCase() + if (!message) return false + return message.includes("fetch failed") || message.includes("network error") || message.includes("network down") +} + function isAuthExpiredStatus(status: number): boolean { return status === 401 || status === 403 } @@ -55,6 +106,7 @@ export function createOAuthRotatingFetch attempt + 1 < maxAttempts + let networkRetryAttempts = allowRetry ? configuredNetworkRetryAttempts : 0 + + const runWithNetworkRetry = async (): Promise => { + for (let networkAttempt = 0; ; networkAttempt++) { + let attemptInput = input + if (inputIsRequest && allowRetry) { + try { + attemptInput = input.clone() + } catch (e) { + lastError = e + allowRetry = false + networkRetryAttempts = 0 + maxAttempts = attempt + 1 + } + } + + try { + return await withOAuthRecord(opts.providerID, nextID, () => fetchFn(attemptInput, init)) + } catch (e) { + lastError = e + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: 0, + ok: false, + }) + const networkError = isNetworkError(e) + if (networkError && allowRetry && networkAttempt < networkRetryAttempts) { + continue + } + throw e + } } } - - const hasMoreAttempts = attempt + 1 < maxAttempts - - const run = () => withOAuthRecord(opts.providerID, nextID, () => fetchFn(attemptInput, init)) const notifyFailover = async (statusCode: number) => { const candidate = pickNextCandidate(Date.now()) if (!candidate) return @@ -128,18 +205,13 @@ export function createOAuthRotatingFetch { expect(order[1]).toBe(a1.recordID) }) + test("fails over on non-auth errors", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + if (auth.refresh === "r1") { + return new Response("payment required", { status: 402 }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a2.recordID) + expect(order[1]).toBe(a1.recordID) + }) + test("sticks to the active credential until rate limited", async () => { await Auth.remove(providerID) @@ -285,7 +321,7 @@ describe("OAuth subscription failover", () => { expect(recordByID.get(a2.recordID)?.health.failureCount ?? 0).toBe(1) }) - test("fails over when a request throws", async () => { + test("retries network errors without rotating", async () => { await Auth.remove(providerID) const a1 = await Auth.addOAuth(providerID, { @@ -299,13 +335,20 @@ describe("OAuth subscription failover", () => { expires: Date.now() + 60_000, }) + const counts = new Map() + let failures = 0 const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { const auth = await Auth.get(providerID) expect(auth?.type).toBe("oauth") if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) - if (auth.refresh === "r1") { - throw new Error("network down") + counts.set(auth.refresh, (counts.get(auth.refresh) ?? 0) + 1) + + if (auth.refresh === "r1" && failures < 1) { + failures += 1 + const error = new Error("network down") + ;(error as { code?: string }).code = "ECONNRESET" + throw error } return new Response("ok", { status: 200 }) @@ -316,10 +359,12 @@ describe("OAuth subscription failover", () => { expect(response.status).toBe(200) - const records = await Auth.OAuthPool.list(providerID) - const recordByID = new Map(records.map((record) => [record.id, record])) - expect(recordByID.get(a1.recordID)?.health.failureCount ?? 0).toBe(1) - expect(recordByID.get(a2.recordID)?.health.failureCount ?? 0).toBe(0) + expect(counts.get("r1") ?? 0).toBe(2) + expect(counts.get("r2") ?? 0).toBe(0) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a1.recordID) + expect(order[1]).toBe(a2.recordID) }) test("respects Retry-After HTTP date headers", async () => {