diff --git a/packages/app/src/context/local.tsx b/packages/app/src/context/local.tsx index ac5da60e862..d504983f03b 100644 --- a/packages/app/src/context/local.tsx +++ b/packages/app/src/context/local.tsx @@ -90,9 +90,9 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ }) const resolveConfigured = () => { - if (!sync.data.config.model) return - const [providerID, modelID] = sync.data.config.model.split("/") - const key = { providerID, modelID } + const configured = sync.data.config.model + if (!configured) return + const key = { providerID: configured.providerID, modelID: configured.id } if (isModelValid(key)) return key } diff --git a/packages/desktop/src/types/tauri-clipboard-manager.d.ts b/packages/desktop/src/types/tauri-clipboard-manager.d.ts new file mode 100644 index 00000000000..41f30aa61da --- /dev/null +++ b/packages/desktop/src/types/tauri-clipboard-manager.d.ts @@ -0,0 +1,11 @@ +declare module "@tauri-apps/plugin-clipboard-manager" { + type ClipboardImage = { + rgba(): Promise + size(): Promise<{ + width: number + height: number + }> + } + + export function readImage(): Promise +} diff --git a/packages/opencode/src/auth/index.ts b/packages/opencode/src/auth/index.ts index ce948b92ac8..972dd4d4f7d 100644 --- a/packages/opencode/src/auth/index.ts +++ b/packages/opencode/src/auth/index.ts @@ -13,6 +13,7 @@ export namespace Auth { expires: z.number(), accountId: z.string().optional(), enterpriseUrl: z.string().optional(), + email: z.string().optional(), // For multi-account identification }) .meta({ ref: "OAuth" }) @@ -20,6 +21,7 @@ export namespace Auth { .object({ type: z.literal("api"), key: z.string(), + email: z.string().optional(), // For multi-account identification }) .meta({ ref: "ApiAuth" }) @@ -34,37 +36,348 @@ export namespace Auth { export const Info = z.discriminatedUnion("type", [Oauth, Api, WellKnown]).meta({ ref: "Auth" }) export type Info = z.infer + // Multi-account storage format + export type AccountStore = { + [provider: string]: { + accounts: { + [accountId: string]: Info & { disabled?: boolean } + } + // Current active account for this provider + activeAccount?: string + } + } + const filepath = path.join(Global.Path.data, "auth.json") - export async function get(providerID: string) { - const auth = await all() - return auth[providerID] + /** + * Get credential for a provider. + * Uses activeAccount if set, otherwise returns first available. + * Precedence order: + * 1. Explicit account parameter (if provided) + * 2. Config project override (auth.account in opencode.json) + * 3. Environment variable overrides: + * - OPENCODE_ACCOUNT_: e.g., OPENCODE_ACCOUNT_OPENAI=work + * - OPENCODE_ACCOUNT: general override for any provider + * 4. Active account for the provider + * 5. First available non-disabled account + */ + export async function get(providerID: string, explicitAccount?: string): Promise { + const store = await load() + const provider = store[providerID] + + if (!provider || !provider.accounts) return undefined + + // 1. Use explicit account if provided + if (explicitAccount && provider.accounts[explicitAccount] && !provider.accounts[explicitAccount].disabled) { + return provider.accounts[explicitAccount] + } + + // 2. Check for config project override (lazy import to avoid circular dependency) + let configAccount: string | undefined + try { + const { Config } = await import("../config/config") + const config = await Config.get() + configAccount = config.auth?.[providerID] + if (configAccount && provider.accounts[configAccount] && !provider.accounts[configAccount].disabled) { + return provider.accounts[configAccount] + } + } catch { + // Config may not be available in all contexts + } + + // 3. Check for environment variable overrides + const envVarName = `OPENCODE_ACCOUNT_${providerID.toUpperCase()}` + const envAccount = process.env[envVarName] || process.env["OPENCODE_ACCOUNT"] + + if (envAccount) { + // Use the specific account from env var + const account = provider.accounts[envAccount] + if (account && !account.disabled) { + return account + } + // If account not found or disabled, fall through to active account + } + + // 4. Use active account if set + if (provider.activeAccount && provider.accounts[provider.activeAccount]) { + return provider.accounts[provider.activeAccount] + } + + // 5. Otherwise, find first non-disabled account + for (const [id, info] of Object.entries(provider.accounts)) { + if (!info.disabled) return info + } + + return undefined } - export async function all(): Promise> { - const file = Bun.file(filepath) - const data = await file.json().catch(() => ({}) as Record) - return Object.entries(data).reduce( - (acc, [key, value]) => { - const parsed = Info.safeParse(value) - if (!parsed.success) return acc - acc[key] = parsed.data - return acc - }, - {} as Record, + /** + * Get all accounts for a provider. + */ + export async function getAccounts(providerID: string): Promise> { + const store = await load() + const provider = store[providerID] + return provider?.accounts ?? {} + } + + /** + * List all providers and their accounts. + */ + export async function all(): Promise { + return await load() + } + + /** + * Add a new credential to a provider. + * Automatically creates a new account entry. + * Returns the account ID (email or generated). + */ + export async function add(providerID: string, info: Info): Promise { + const store = await load() + + if (!store[providerID]) { + store[providerID] = { accounts: {} } + } + + // Generate account ID from email if available, otherwise use timestamp + let accountId = "default" + if ("email" in info && info.email) { + accountId = info.email + } else if ("refresh" in info && info.refresh) { + // For OAuth, use a hash of refresh token + accountId = `oauth-${Date.now()}` + } else { + accountId = `key-${Date.now()}` + } + + store[providerID].accounts[accountId] = info + + // If this is the first account, set as active + if (!store[providerID].activeAccount) { + store[providerID].activeAccount = accountId + } + + await save(store) + return accountId + } + + /** + * Set credential (alias for add for compatibility) + */ + export async function set(providerID: string, info: Info, account?: string) { + // For backwards compatibility: if account is provided, use it + if (account) { + const store = await load() + if (!store[providerID]) { + store[providerID] = { accounts: {} } + } + store[providerID].accounts[account] = info + if (!store[providerID].activeAccount) { + store[providerID].activeAccount = account + } + await save(store) + } else { + // Otherwise add as new account + await add(providerID, info) + } + } + + /** + * Remove an account from a provider. + * If account is "all" or not specified, removes all. + */ + export async function remove(providerID: string, account?: string) { + const store = await load() + const provider = store[providerID] + + if (!provider) return + + if (!account) { + // Remove all accounts for this provider + delete store[providerID] + } else if (account === "all") { + delete store[providerID] + } else { + delete provider.accounts[account] + + // If we removed the active account, switch to another + if (provider.activeAccount === account) { + const remaining = Object.keys(provider.accounts) + provider.activeAccount = remaining[0] ?? undefined + } + } + + await save(store) + } + + /** + * List all accounts for a provider. + */ + export async function list(providerID: string): Promise { + const store = await load() + const provider = store[providerID] + return provider ? Object.keys(provider.accounts) : [] + } + + /** + * Get active account for a provider. + */ + export async function getActiveAccount(providerID: string): Promise { + const store = await load() + return store[providerID]?.activeAccount + } + + /** + * Set active account for a provider. + */ + export async function use(providerID: string, account: string) { + const store = await load() + const provider = store[providerID] + + if (!provider || !provider.accounts[account]) { + throw new Error(`Account ${account} not found for provider ${providerID}`) + } + + provider.activeAccount = account + await save(store) + } + + /** + * Enable/disable an account (for rotation). + */ + export async function setEnabled(providerID: string, account: string, enabled: boolean) { + const store = await load() + const provider = store[providerID] + + if (!provider || !provider.accounts[account]) return + + provider.accounts[account].disabled = !enabled + + // If we disabled the active account, switch to another + if (!enabled && provider.activeAccount === account) { + const remaining = Object.keys(provider.accounts).filter((a) => !provider.accounts[a].disabled) + provider.activeAccount = remaining[0] ?? undefined + } + + await save(store) + } + + /** + * Get next available account (for auto-rotation on rate-limit). + */ + export async function getNextAccount(providerID: string): Promise<{ account: string; info: Info } | undefined> { + const store = await load() + const provider = store[providerID] + + if (!provider || !provider.accounts) return undefined + + const accounts = Object.entries(provider.accounts).filter(([_, info]) => !info.disabled) + + if (accounts.length === 0) return undefined + + // Simple round-robin: switch to next account + const currentActive = provider.activeAccount + const currentIndex = accounts.findIndex(([id]) => id === currentActive) + const nextIndex = (currentIndex + 1) % accounts.length + + const [account, info] = accounts[nextIndex] + provider.activeAccount = account + await save(store) + + return { account, info } + } + + /** + * Check if an error is a rate-limit error (429). + */ + export function isRateLimitError(error: unknown): boolean { + if (!error) return false + + const errorObj = error as any + const status = errorObj?.status || errorObj?.statusCode + + if (status === 429) return true + + // Check for common rate-limit messages + const message = errorObj?.message || String(error) + const lowerMessage = message.toLowerCase() + return ( + lowerMessage.includes("rate limit") || + lowerMessage.includes("too many requests") || + lowerMessage.includes("rate_limit") || + lowerMessage.includes("429") ) } - export async function set(key: string, info: Info) { + /** + * Execute a function with automatic account rotation on rate-limit. + * Returns the result if successful, or throws if all accounts are exhausted. + */ + export async function withRetry( + providerID: string, + fn: (info: Info) => Promise, + maxRetries?: number, + ): Promise { + const max = maxRetries ?? 10 + let lastError: unknown + + for (let i = 0; i < max; i++) { + const info = await get(providerID) + if (!info) { + throw new Error(`No credentials found for provider ${providerID}`) + } + + try { + return await fn(info) + } catch (error) { + lastError = error + + if (!isRateLimitError(error)) { + // Not a rate-limit error, throw immediately + throw error + } + + // Rate-limit error - try next account + console.log(`[auth] Rate limited on ${providerID}, switching account...`) + + const next = await getNextAccount(providerID) + if (!next) { + throw new Error(`Rate limited and no more accounts available for ${providerID}`) + } + } + } + + throw lastError + } + + /** + * Legacy compatibility: convert old format to new. + */ + async function migrateIfNeeded(store: AccountStore): Promise { + // Check if it's in legacy format (direct Info objects instead of { accounts: {} }) + for (const [providerID, value] of Object.entries(store)) { + if (value && typeof value === "object" && !("accounts" in value)) { + // Legacy format - migrate + const info = value as unknown as Info + store[providerID] = { + accounts: { default: info }, + activeAccount: "default", + } + } + } + return store + } + + // Load auth data from file + async function load(): Promise { const file = Bun.file(filepath) - const data = await all() - await Bun.write(file, JSON.stringify({ ...data, [key]: info }, null, 2), { mode: 0o600 }) + const data = await file.json().catch(() => ({})) + return await migrateIfNeeded(data) } - export async function remove(key: string) { + // Save auth data to file + async function save(store: AccountStore) { const file = Bun.file(filepath) - const data = await all() - delete data[key] - await Bun.write(file, JSON.stringify(data, null, 2), { mode: 0o600 }) + await Bun.write(file, JSON.stringify(store, null, 2), { mode: 0o600 }) } } diff --git a/packages/opencode/src/cli/cmd/auth.ts b/packages/opencode/src/cli/cmd/auth.ts index 34e2269d0c1..07663017869 100644 --- a/packages/opencode/src/cli/cmd/auth.ts +++ b/packages/opencode/src/cli/cmd/auth.ts @@ -16,7 +16,6 @@ type PluginAuth = NonNullable /** * Handle plugin-based authentication flow. - * Returns true if auth was handled, false if it should fall through to default handling. */ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): Promise { let index = 0 @@ -35,7 +34,6 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): } const method = plugin.auth.methods[index] - // Handle prompts for all auth types await Bun.sleep(10) const inputs: Record = {} if (method.prompts) { @@ -83,7 +81,7 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): const saveProvider = result.provider ?? provider if ("refresh" in result) { const { type: _, provider: __, refresh, access, expires, ...extraFields } = result - await Auth.set(saveProvider, { + await Auth.add(saveProvider, { type: "oauth", refresh, access, @@ -92,7 +90,7 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): }) } if ("key" in result) { - await Auth.set(saveProvider, { + await Auth.add(saveProvider, { type: "api", key: result.key, }) @@ -115,7 +113,7 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): const saveProvider = result.provider ?? provider if ("refresh" in result) { const { type: _, provider: __, refresh, access, expires, ...extraFields } = result - await Auth.set(saveProvider, { + await Auth.add(saveProvider, { type: "oauth", refresh, access, @@ -124,7 +122,7 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): }) } if ("key" in result) { - await Auth.set(saveProvider, { + await Auth.add(saveProvider, { type: "api", key: result.key, }) @@ -145,7 +143,7 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): } if (result.type === "success") { const saveProvider = result.provider ?? provider - await Auth.set(saveProvider, { + await Auth.add(saveProvider, { type: "api", key: result.key, }) @@ -163,29 +161,48 @@ export const AuthCommand = cmd({ command: "auth", describe: "manage credentials", builder: (yargs) => - yargs.command(AuthLoginCommand).command(AuthLogoutCommand).command(AuthListCommand).demandCommand(), + yargs + .command(AuthLoginCommand) + .command(AuthLogoutCommand) + .command(AuthListCommand) + .command(AuthUseCommand) + .demandCommand(), async handler() {}, }) export const AuthListCommand = cmd({ command: "list", aliases: ["ls"], - describe: "list providers", + describe: "list providers and accounts", async handler() { UI.empty() const authPath = path.join(Global.Path.data, "auth.json") const homedir = os.homedir() const displayPath = authPath.startsWith(homedir) ? authPath.replace(homedir, "~") : authPath prompts.intro(`Credentials ${UI.Style.TEXT_DIM}${displayPath}`) - const results = Object.entries(await Auth.all()) + const results = await Auth.all() const database = await ModelsDev.get() - for (const [providerID, result] of results) { + // Group by provider + for (const [providerID, providerData] of Object.entries(results)) { const name = database[providerID]?.name || providerID - prompts.log.info(`${name} ${UI.Style.TEXT_DIM}${result.type}`) + + // Show provider name + prompts.log.info(`${UI.Style.TEXT_NORMAL_BOLD}${name}`) + + // Show all accounts for this provider + if (providerData.accounts) { + for (const [accountId, info] of Object.entries(providerData.accounts)) { + const isActive = accountId === providerData.activeAccount + const isDisabled = "disabled" in info && info.disabled + const marker = isActive ? " ✓" : (isDisabled ? " (disabled)" : "") + const label = accountId === "default" ? "default" : accountId + prompts.log.info(` ${label}${marker} ${UI.Style.TEXT_DIM}(${info.type})`) + } + } } - prompts.outro(`${results.length} credentials`) + prompts.outro(`${Object.keys(results).length} providers`) // Environment variables section const activeEnvVars: Array<{ provider: string; envVar: string }> = [] @@ -228,7 +245,12 @@ export const AuthLoginCommand = cmd({ async fn() { UI.empty() prompts.intro("Add credential") + + // Check if provider already has accounts - offer options + const existingProviders = await Auth.all() + if (args.url) { + // Well-known auth const wellknown = await fetch(`${args.url}/.well-known/opencode`).then((x) => x.json() as any) prompts.log.info(`Running \`${wellknown.auth.command.join(" ")}\``) const proc = Bun.spawn({ @@ -242,7 +264,7 @@ export const AuthLoginCommand = cmd({ return } const token = await new Response(proc.stdout).text() - await Auth.set(args.url, { + await Auth.add(args.url, { type: "wellknown", key: wellknown.auth.env, token: token.trim(), @@ -251,6 +273,7 @@ export const AuthLoginCommand = cmd({ prompts.outro("Done") return } + await ModelsDev.refresh().catch(() => {}) const config = await Config.get() @@ -307,6 +330,63 @@ export const AuthLoginCommand = cmd({ if (prompts.isCancel(provider)) throw new UI.CancelledError() + // Check if this provider already has accounts + const hasExistingAccounts = existingProviders[provider] && + Object.keys(existingProviders[provider].accounts || {}).length > 0 + + if (hasExistingAccounts) { + // Ask what to do: add another, switch, or manage + const action = await prompts.select({ + message: "This provider already has accounts. What would you like to do?", + options: [ + { label: "Add another account", value: "add" }, + { label: "Switch active account", value: "switch" }, + { label: "Manage accounts (enable/disable)", value: "manage" }, + ], + }) + if (prompts.isCancel(action)) throw new UI.CancelledError() + + if (action === "switch") { + const accounts = await Auth.list(provider) + const currentActive = await Auth.getActiveAccount(provider) + const selected = await prompts.select({ + message: "Select active account", + options: accounts.map(acc => ({ + label: acc === "default" ? "default" : acc, + value: acc, + })), + }) + if (prompts.isCancel(selected)) throw new UI.CancelledError() + await Auth.use(provider, selected) + prompts.log.success(`Switched to ${selected}`) + prompts.outro("Done") + return + } + + if (action === "manage") { + const accounts = await Auth.list(provider) + const selected = await prompts.select({ + message: "Select account to toggle", + options: [ + ...accounts.map(acc => ({ + label: acc === "default" ? "default" : acc, + value: acc, + })), + ], + }) + if (prompts.isCancel(selected)) throw new UI.CancelledError() + + const currentAccounts = await Auth.getAccounts(provider) + const isDisabled = currentAccounts[selected]?.disabled ?? false + + await Auth.setEnabled(provider, selected, isDisabled) + prompts.log.success(isDisabled ? "Account enabled" : "Account disabled") + prompts.outro("Done") + return + } + // If "add", continue to authentication + } + const plugin = await Plugin.list().then((x) => x.findLast((x) => x.auth?.provider === provider)) if (plugin && plugin.auth) { const handled = await handlePluginAuth({ auth: plugin.auth }, provider) @@ -316,13 +396,12 @@ export const AuthLoginCommand = cmd({ if (provider === "other") { provider = await prompts.text({ message: "Enter provider id", - validate: (x) => (x && x.match(/^[0-9a-z-]+$/) ? undefined : "a-z, 0-9 and hyphens only"), + validate: (x) => (x && x.match(/^[0-9a-z-]+$/)) ? undefined : "a-z, 0-9 and hyphens only", }) if (prompts.isCancel(provider)) throw new UI.CancelledError() provider = provider.replace(/^@ai-sdk\//, "") if (prompts.isCancel(provider)) throw new UI.CancelledError() - // Check if a plugin provides auth for this custom provider const customPlugin = await Plugin.list().then((x) => x.findLast((x) => x.auth?.provider === provider)) if (customPlugin && customPlugin.auth) { const handled = await handlePluginAuth({ auth: customPlugin.auth }, provider) @@ -363,11 +442,23 @@ export const AuthLoginCommand = cmd({ validate: (x) => (x && x.length > 0 ? undefined : "Required"), }) if (prompts.isCancel(key)) throw new UI.CancelledError() - await Auth.set(provider, { + + // Ask for email (optional, for identification) + const email = await prompts.text({ + message: "Account name/email (optional, for identification)", + placeholder: "e.g., work, personal, user@gmail.com", + }) + + const info: Auth.Info = { type: "api", key, - }) - + } + + if (email && !prompts.isCancel(email)) { + (info as any).email = email.trim() + } + + await Auth.add(provider, info) prompts.outro("Done") }, }) @@ -379,22 +470,121 @@ export const AuthLogoutCommand = cmd({ describe: "log out from a configured provider", async handler() { UI.empty() - const credentials = await Auth.all().then((x) => Object.entries(x)) + const credentials = await Auth.all() + const providers = Object.keys(credentials) + prompts.intro("Remove credential") - if (credentials.length === 0) { + if (providers.length === 0) { prompts.log.error("No credentials found") return } + const database = await ModelsDev.get() + + // Show provider selection with account count const providerID = await prompts.select({ message: "Select provider", - options: credentials.map(([key, value]) => ({ - label: (database[key]?.name || key) + UI.Style.TEXT_DIM + " (" + value.type + ")", - value: key, - })), + options: providers.map(key => { + const accountCount = Object.keys(credentials[key].accounts || {}).length + return { + label: (database[key]?.name || key) + UI.Style.TEXT_DIM + ` (${accountCount} account${accountCount !== 1 ? "s" : ""})`, + value: key, + } + }), }) if (prompts.isCancel(providerID)) throw new UI.CancelledError() - await Auth.remove(providerID) + + // Show account selection + const accounts = await Auth.list(providerID) + if (accounts.length > 1) { + const accountToRemove = await prompts.select({ + message: "Select account to remove", + options: [ + { label: "All accounts", value: "all" }, + ...accounts.map(acc => ({ + label: acc === "default" ? "default" : acc, + value: acc, + })), + ], + }) + if (prompts.isCancel(accountToRemove)) throw new UI.CancelledError() + + await Auth.remove(providerID, accountToRemove) + } else { + await Auth.remove(providerID) + } + prompts.outro("Logout successful") }, }) + +export const AuthUseCommand = cmd({ + command: "use", + describe: "switch between accounts for a provider", + builder: (yargs) => + yargs + .positional("provider", { + describe: "provider id", + type: "string", + }) + .positional("account", { + describe: "account name or email", + type: "string", + }), + async handler(args) { + if (!args.provider || !args.account) { + // Interactive mode + const credentials = await Auth.all() + const providers = Object.keys(credentials) + + if (providers.length === 0) { + prompts.log.error("No providers found") + return + } + + const database = await ModelsDev.get() + + const providerID = await prompts.select({ + message: "Select provider", + options: providers.map(key => ({ + label: database[key]?.name || key, + value: key, + })), + }) + if (prompts.isCancel(providerID)) throw new UI.CancelledError() + + const accounts = await Auth.list(providerID) + if (accounts.length === 0) { + prompts.log.error("No accounts found for this provider") + return + } + + const account = await prompts.select({ + message: "Select account", + options: accounts.map(acc => ({ + label: acc === "default" ? "default" : acc, + value: acc, + })), + }) + if (prompts.isCancel(account)) throw new UI.CancelledError() + + try { + await Auth.use(providerID, account) + prompts.log.success(`Switched to ${account}`) + } catch (error) { + prompts.log.error(String(error)) + } + prompts.outro("Done") + return + } + + // Direct mode + try { + await Auth.use(args.provider, args.account) + prompts.log.success(`Switched to ${args.account} for ${args.provider}`) + } catch (error) { + prompts.log.error(String(error)) + } + prompts.outro("Done") + }, +}) diff --git a/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx b/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx index 9682bee4ead..aad9b08957c 100644 --- a/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx +++ b/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx @@ -13,6 +13,7 @@ import { DialogModel } from "./dialog-model" import { useKeyboard } from "@opentui/solid" import { Clipboard } from "@tui/util/clipboard" import { useToast } from "../ui/toast" +import { Provider } from "@/provider/provider" const PROVIDER_PRIORITY: Record = { opencode: 0, @@ -40,6 +41,29 @@ export function createDialogProviderOptions() { }[provider.id], category: provider.id in PROVIDER_PRIORITY ? "Popular" : "Other", async onSelect() { + const connected = sync.data.provider_next.connected ?? [] + const isConnected = connected.includes(provider.id) + + if (isConnected) { + const action = await new Promise((resolve) => { + dialog.replace(() => ( + resolve(opt.value)} + /> + )) + }) + + if (action === "switch") { + await dialog.replace(() => ) + return + } + } + const methods = sync.data.provider_auth[provider.id] ?? [ { type: "api", @@ -92,6 +116,77 @@ export function createDialogProviderOptions() { return options } +interface SwitchAccountDialogProps { + providerID: string + providerName: string +} + +function SwitchAccountDialog(props: SwitchAccountDialogProps) { + const dialog = useDialog() + const sdk = useSDK() + const sync = useSync() + const toast = useToast() + const [accounts, setAccounts] = createSignal< + Array<{ + id: string + type: string + disabled: boolean + }> + >([]) + const [activeAccount, setActiveAccount] = createSignal(undefined) + + onMount(() => { + sdk.client.auth + .list() + .then((result) => { + const provider = result.data?.[props.providerID] + if (!provider) { + setAccounts([]) + setActiveAccount(undefined) + return + } + + setActiveAccount(provider.activeAccount) + setAccounts( + Object.entries(provider.accounts ?? {}).map(([id, info]) => ({ + id, + type: info.type, + disabled: (info as { disabled?: boolean }).disabled === true, + })), + ) + }) + .catch(toast.error) + }) + + return ( + ({ + title: account.id, + value: account.id, + description: account.id === activeAccount() ? "Active" : account.type, + footer: account.disabled ? "Disabled" : undefined, + disabled: account.disabled, + }))} + current={activeAccount()} + onSelect={(option) => { + sdk.client.auth + .use( + { + providerID: props.providerID, + account: option.value, + }, + { throwOnError: true }, + ) + .then(() => sdk.client.instance.dispose()) + .then(() => sync.bootstrap()) + .then(() => dialog.clear()) + .catch(toast.error) + }} + /> + ) +} + export function DialogProvider() { const options = createDialogProviderOptions() return diff --git a/packages/opencode/src/config/config.ts b/packages/opencode/src/config/config.ts index 8f0f583ea3d..e584f1871a8 100644 --- a/packages/opencode/src/config/config.ts +++ b/packages/opencode/src/config/config.ts @@ -76,9 +76,12 @@ export namespace Config { // 6) Inline config (OPENCODE_CONFIG_CONTENT) // Managed config directory is enterprise-only and always overrides everything above. let result: Info = {} - for (const [key, value] of Object.entries(auth)) { - if (value.type === "wellknown") { - process.env[value.key] = value.token + for (const [key, providerData] of Object.entries(auth)) { + const activeAccountId = providerData.activeAccount + if (!activeAccountId || !providerData.accounts[activeAccountId]) continue + const account = providerData.accounts[activeAccountId] + if (account.type === "wellknown") { + process.env[account.key] = account.token log.debug("fetching remote config", { url: `${key}/.well-known/opencode` }) const response = await fetch(`${key}/.well-known/opencode`) if (!response.ok) { @@ -1084,6 +1087,12 @@ export namespace Config { .record(z.string(), Provider) .optional() .describe("Custom provider configurations and model overrides"), + auth: z + .record(z.string(), z.string()) + .optional() + .describe( + 'Account to use per provider. Use provider ID as key and account name as value (e.g., { "openai": "work", "anthropic": "personal" })', + ), mcp: z .record( z.string(), diff --git a/packages/opencode/src/provider/auth.ts b/packages/opencode/src/provider/auth.ts index e6681ff0891..22258196831 100644 --- a/packages/opencode/src/provider/auth.ts +++ b/packages/opencode/src/provider/auth.ts @@ -93,7 +93,7 @@ export namespace ProviderAuth { if (result?.type === "success") { if ("key" in result) { - await Auth.set(input.providerID, { + await Auth.add(input.providerID, { type: "api", key: result.key, }) @@ -108,7 +108,7 @@ export namespace ProviderAuth { if (result.accountId) { info.accountId = result.accountId } - await Auth.set(input.providerID, info) + await Auth.add(input.providerID, info) } return } @@ -123,7 +123,7 @@ export namespace ProviderAuth { key: z.string(), }), async (input) => { - await Auth.set(input.providerID, { + await Auth.add(input.providerID, { type: "api", key: input.key, }) diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index d76cc902ae6..f4a1ae80fbc 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -848,12 +848,16 @@ export namespace Provider { } // load apikeys - for (const [providerID, provider] of Object.entries(await Auth.all())) { + for (const [providerID, providerData] of Object.entries(await Auth.all())) { if (disabled.has(providerID)) continue - if (provider.type === "api") { + const accounts = providerData.accounts + const activeAccountId = providerData.activeAccount + if (!activeAccountId || !accounts[activeAccountId]) continue + const account = accounts[activeAccountId] + if (account.type === "api") { mergeProvider(providerID, { source: "api", - key: provider.key, + key: account.key, }) } } @@ -1241,7 +1245,14 @@ export namespace Provider { } } - export function parseModel(model: string) { + export function parseModel(model: string | { providerID: string; id: string }) { + if (typeof model !== "string") { + return { + providerID: model.providerID, + modelID: model.id, + } + } + const [providerID, ...rest] = model.split("/") return { providerID: providerID, diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index 9fb5206551b..6da74005d91 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -130,6 +130,63 @@ export namespace Server { }), ) .route("/global", GlobalRoutes()) + .get( + "/auth", + describeRoute({ + summary: "List all auth accounts", + description: "Get all providers and their accounts", + operationId: "auth.list", + responses: { + 200: { + description: "All auth accounts", + content: { + "application/json": { + schema: resolver( + z.record( + z.string(), + z.object({ + accounts: z.record(z.string(), Auth.Info), + activeAccount: z.string().optional(), + }), + ), + ), + }, + }, + }, + }, + }), + async (c) => { + return c.json(await Auth.all()) + }, + ) + .get( + "/auth/:providerID", + describeRoute({ + summary: "Get provider accounts", + description: "Get all accounts for a specific provider", + operationId: "auth.getAccounts", + responses: { + 200: { + description: "Provider accounts", + content: { + "application/json": { + schema: resolver(z.record(z.string(), Auth.Info)), + }, + }, + }, + }, + }), + validator( + "param", + z.object({ + providerID: z.string(), + }), + ), + async (c) => { + const providerID = c.req.valid("param").providerID + return c.json(await Auth.getAccounts(providerID)) + }, + ) .put( "/auth/:providerID", describeRoute({ @@ -192,6 +249,43 @@ export namespace Server { return c.json(true) }, ) + .post( + "/auth/:providerID/use", + describeRoute({ + summary: "Switch active account", + description: "Switch the active account for a provider", + operationId: "auth.use", + responses: { + 200: { + description: "Successfully switched account", + content: { + "application/json": { + schema: resolver(z.boolean()), + }, + }, + }, + ...errors(400), + }, + }), + validator( + "param", + z.object({ + providerID: z.string(), + }), + ), + validator( + "json", + z.object({ + account: z.string(), + }), + ), + async (c) => { + const providerID = c.req.valid("param").providerID + const { account } = c.req.valid("json") + await Auth.use(providerID, account) + return c.json(true) + }, + ) .use(async (c, next) => { if (c.req.path === "/log") return next() const raw = c.req.query("directory") || c.req.header("x-opencode-directory") || process.cwd() diff --git a/packages/opencode/test/auth/auth.test.ts b/packages/opencode/test/auth/auth.test.ts new file mode 100644 index 00000000000..f862da7f229 --- /dev/null +++ b/packages/opencode/test/auth/auth.test.ts @@ -0,0 +1,396 @@ +import { describe, expect, test, beforeEach, afterEach } from "bun:test" +import { Auth } from "../../src/auth" +import type { Auth as AuthType } from "../../src/auth" +import { tmpdir } from "../fixture/fixture" +import { Instance } from "../../src/project/instance" +import { Global } from "../../src/global" +import path from "path" +import fs from "fs/promises" + +function asApiAuth(info: AuthType.Info | undefined): { key: string } | undefined { + return info as any +} + +describe("auth multi-account", () => { + let testAuthPath: string + + beforeEach(async () => { + testAuthPath = path.join(Global.Path.data, "auth.json") + await fs.rm(testAuthPath, { force: true }).catch(() => {}) + }) + + afterEach(async () => { + await fs.rm(testAuthPath, { force: true }).catch(() => {}) + }) + + test("add creates first account as active", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const accountId = await Auth.add("openai", { + type: "api", + key: "sk-test-key", + }) + + const accounts = await Auth.list("openai") + expect(accounts).toContain(accountId) + expect(await Auth.getActiveAccount("openai")).toBe(accountId) + }, + }) + }) + + test("add second account does not change active account", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "default", + ) + + const firstActive = await Auth.getActiveAccount("openai") + + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-2", + }, + "work", + ) + + expect(await Auth.getActiveAccount("openai")).toBe(firstActive) + }, + }) + }) + + test("use changes active account", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "default", + ) + + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-2", + }, + "work", + ) + + await Auth.use("openai", "work") + expect(await Auth.getActiveAccount("openai")).toBe("work") + + const creds = asApiAuth(await Auth.get("openai")) + expect(creds?.key).toBe("sk-key-2") + }, + }) + }) + + test("remove non-active account preserves active", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "default", + ) + + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-2", + }, + "work", + ) + + await Auth.remove("openai", "work") + expect(await Auth.getActiveAccount("openai")).toBe("default") + }, + }) + }) + + test("remove active account promotes another", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "default", + ) + + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-2", + }, + "work", + ) + + await Auth.use("openai", "work") + await Auth.remove("openai", "work") + + const active = await Auth.getActiveAccount("openai") + expect(active).toBeDefined() + expect(active).toBe("default") + }, + }) + }) + + test("remove last account removes provider", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "default", + ) + + await Auth.remove("openai", "default") + + const accounts = await Auth.list("openai") + expect(accounts).toHaveLength(0) + }, + }) + }) + + test("get returns credentials for active account", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "default", + ) + + const creds = asApiAuth(await Auth.get("openai")) + expect(creds?.key).toBe("sk-key-1") + }, + }) + }) + + test("get accepts explicit account parameter", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "default", + ) + + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-2", + }, + "work", + ) + + await Auth.use("openai", "default") + + const creds = asApiAuth(await Auth.get("openai", "work")) + expect(creds?.key).toBe("sk-key-2") + }, + }) + }) + + test("setEnabled can disable and enable accounts", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "default", + ) + + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-2", + }, + "work", + ) + + await Auth.setEnabled("openai", "work", false) + + const creds = asApiAuth(await Auth.get("openai")) + expect(creds?.key).toBe("sk-key-1") + + await Auth.setEnabled("openai", "work", true) + await Auth.use("openai", "work") + + const creds2 = asApiAuth(await Auth.get("openai")) + expect(creds2?.key).toBe("sk-key-2") + }, + }) + }) + + test("getAccounts returns all accounts for provider", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "default", + ) + + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-2", + }, + "work", + ) + + const accounts = await Auth.getAccounts("openai") + expect(Object.keys(accounts)).toHaveLength(2) + }, + }) + }) + + test("set with account parameter stores credentials under that account", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-key-1", + }, + "work", + ) + + const creds = asApiAuth(await Auth.get("openai", "work")) + expect(creds?.key).toBe("sk-key-1") + expect(await Auth.getActiveAccount("openai")).toBe("work") + }, + }) + }) + + test("all returns all providers and accounts", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Auth.set( + "openai", + { + type: "api", + key: "sk-openai", + }, + "default", + ) + + await Auth.set( + "anthropic", + { + type: "api", + key: "sk-anthropic", + }, + "default", + ) + + const all = await Auth.all() + expect(all.openai).toBeDefined() + expect(all.anthropic).toBeDefined() + expect(Object.keys(all.openai.accounts)).toHaveLength(1) + expect(Object.keys(all.anthropic.accounts)).toHaveLength(1) + }, + }) + }) +}) + +describe("auth legacy migration", () => { + let testAuthPath: string + + beforeEach(async () => { + testAuthPath = path.join(Global.Path.data, "auth.json") + await fs.rm(testAuthPath, { force: true }).catch(() => {}) + }) + + afterEach(async () => { + await fs.rm(testAuthPath, { force: true }).catch(() => {}) + }) + + test("migrates legacy format on read", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await Bun.write( + testAuthPath, + JSON.stringify({ + openai: { + type: "api", + key: "sk-legacy-key", + }, + }), + ) + + const creds = asApiAuth(await Auth.get("openai")) + expect(creds?.key).toBe("sk-legacy-key") + + const all = await Auth.all() + expect(all.openai.accounts.default).toBeDefined() + expect(all.openai.activeAccount).toBe("default") + }, + }) + }) +}) diff --git a/packages/opencode/test/config/config.test.ts b/packages/opencode/test/config/config.test.ts index 91b87f6498c..616bf8936e3 100644 --- a/packages/opencode/test/config/config.test.ts +++ b/packages/opencode/test/config/config.test.ts @@ -1471,9 +1471,14 @@ test("project config overrides remote well-known config", async () => { Auth.all = mock(() => Promise.resolve({ "https://example.com": { - type: "wellknown" as const, - key: "TEST_TOKEN", - token: "test-token", + accounts: { + default: { + type: "wellknown" as const, + key: "TEST_TOKEN", + token: "test-token", + }, + }, + activeAccount: "default", }, }), ) diff --git a/packages/sdk/js/src/v2/gen/sdk.gen.ts b/packages/sdk/js/src/v2/gen/sdk.gen.ts index af79c44a17a..310edddd37e 100644 --- a/packages/sdk/js/src/v2/gen/sdk.gen.ts +++ b/packages/sdk/js/src/v2/gen/sdk.gen.ts @@ -9,10 +9,14 @@ import type { AppLogResponses, AppSkillsResponses, Auth as Auth3, + AuthGetAccountsResponses, + AuthListResponses, AuthRemoveErrors, AuthRemoveResponses, AuthSetErrors, AuthSetResponses, + AuthUseErrors, + AuthUseResponses, CommandListResponses, Config as Config3, ConfigGetResponses, @@ -301,6 +305,15 @@ export class Global extends HeyApiClient { } export class Auth extends HeyApiClient { + /** + * List all auth accounts + * + * Get all providers and their accounts + */ + public list(options?: Options) { + return (options?.client ?? this.client).get({ url: "/auth", ...options }) + } + /** * Remove auth credentials * @@ -320,6 +333,25 @@ export class Auth extends HeyApiClient { }) } + /** + * Get provider accounts + * + * Get all accounts for a specific provider + */ + public getAccounts( + parameters: { + providerID: string + }, + options?: Options, + ) { + const params = buildClientParams([parameters], [{ args: [{ in: "path", key: "providerID" }] }]) + return (options?.client ?? this.client).get({ + url: "/auth/{providerID}", + ...options, + ...params, + }) + } + /** * Set auth credentials * @@ -354,6 +386,41 @@ export class Auth extends HeyApiClient { }, }) } + + /** + * Switch active account + * + * Switch the active account for a provider + */ + public use( + parameters: { + providerID: string + account?: string + }, + options?: Options, + ) { + const params = buildClientParams( + [parameters], + [ + { + args: [ + { in: "path", key: "providerID" }, + { in: "body", key: "account" }, + ], + }, + ], + ) + return (options?.client ?? this.client).post({ + url: "/auth/{providerID}/use", + ...options, + ...params, + headers: { + "Content-Type": "application/json", + ...options?.headers, + ...params.headers, + }, + }) + } } export class Project extends HeyApiClient { diff --git a/packages/sdk/js/src/v2/gen/types.gen.ts b/packages/sdk/js/src/v2/gen/types.gen.ts index b22b7e9af4e..8d8bdc82afa 100644 --- a/packages/sdk/js/src/v2/gen/types.gen.ts +++ b/packages/sdk/js/src/v2/gen/types.gen.ts @@ -1430,7 +1430,7 @@ export type PermissionConfig = | PermissionActionConfig export type AgentConfig = { - model?: string + model?: Model /** * Default model variant for this agent (applies only when using the agent's configured model). */ @@ -1472,6 +1472,7 @@ export type AgentConfig = { permission?: PermissionConfig [key: string]: | unknown + | Model | string | number | { @@ -1702,7 +1703,7 @@ export type Config = { template: string description?: string agent?: string - model?: string + model?: Model subtask?: boolean } } @@ -1744,14 +1745,8 @@ export type Config = { * When set, ONLY these providers will be enabled. All other providers will be ignored */ enabled_providers?: Array - /** - * Model to use in the format of provider/model, eg anthropic/claude-2 - */ - model?: string - /** - * Small model to use for tasks like title generation in the format of provider/model - */ - small_model?: string + model?: Model + small_model?: Model /** * Default agent to use when none is specified. Must be a primary agent. Falls back to 'build' if not set or if the specified agent is invalid. */ @@ -1787,6 +1782,12 @@ export type Config = { provider?: { [key: string]: ProviderConfig } + /** + * Account to use per provider. Use provider ID as key and account name as value (e.g., { "openai": "work", "anthropic": "personal" }) + */ + auth?: { + [key: string]: string + } /** * MCP (Model Context Protocol) server configurations */ @@ -1898,11 +1899,13 @@ export type OAuth = { expires: number accountId?: string enterpriseUrl?: string + email?: string } export type ApiAuth = { type: "api" key: string + email?: string } export type WellKnownAuth = { @@ -2329,6 +2332,29 @@ export type GlobalDisposeResponses = { export type GlobalDisposeResponse = GlobalDisposeResponses[keyof GlobalDisposeResponses] +export type AuthListData = { + body?: never + path?: never + query?: never + url: "/auth" +} + +export type AuthListResponses = { + /** + * All auth accounts + */ + 200: { + [key: string]: { + accounts: { + [key: string]: Auth + } + activeAccount?: string + } + } +} + +export type AuthListResponse = AuthListResponses[keyof AuthListResponses] + export type AuthRemoveData = { body?: never path: { @@ -2356,6 +2382,26 @@ export type AuthRemoveResponses = { export type AuthRemoveResponse = AuthRemoveResponses[keyof AuthRemoveResponses] +export type AuthGetAccountsData = { + body?: never + path: { + providerID: string + } + query?: never + url: "/auth/{providerID}" +} + +export type AuthGetAccountsResponses = { + /** + * Provider accounts + */ + 200: { + [key: string]: Auth + } +} + +export type AuthGetAccountsResponse = AuthGetAccountsResponses[keyof AuthGetAccountsResponses] + export type AuthSetData = { body?: Auth path: { @@ -2383,6 +2429,35 @@ export type AuthSetResponses = { export type AuthSetResponse = AuthSetResponses[keyof AuthSetResponses] +export type AuthUseData = { + body?: { + account: string + } + path: { + providerID: string + } + query?: never + url: "/auth/{providerID}/use" +} + +export type AuthUseErrors = { + /** + * Bad request + */ + 400: BadRequestError +} + +export type AuthUseError = AuthUseErrors[keyof AuthUseErrors] + +export type AuthUseResponses = { + /** + * Successfully switched account + */ + 200: boolean +} + +export type AuthUseResponse = AuthUseResponses[keyof AuthUseResponses] + export type ProjectListData = { body?: never path?: never diff --git a/packages/ui/src/components/diff-ssr.tsx b/packages/ui/src/components/diff-ssr.tsx index e739afc16d8..9cb132698ed 100644 --- a/packages/ui/src/components/diff-ssr.tsx +++ b/packages/ui/src/components/diff-ssr.tsx @@ -1,9 +1,8 @@ -import { DIFFS_TAG_NAME, FileDiff, type SelectedLineRange, VirtualizedFileDiff } from "@pierre/diffs" +import { DIFFS_TAG_NAME, FileDiff, type SelectedLineRange } from "@pierre/diffs" import { PreloadMultiFileDiffResult } from "@pierre/diffs/ssr" import { createEffect, onCleanup, onMount, Show, splitProps } from "solid-js" import { Dynamic, isServer } from "solid-js/web" import { createDefaultOptions, styleVariables, type DiffProps } from "../pierre" -import { acquireVirtualizer, virtualMetrics } from "../pierre/virtualizer" import { useWorkerPool } from "../context/worker-pool" export type SSRDiffProps = DiffProps & { @@ -25,21 +24,10 @@ export function Diff(props: SSRDiffProps) { const workerPool = useWorkerPool(props.diffStyle) let fileDiffInstance: FileDiff | undefined - let sharedVirtualizer: NonNullable> | undefined const cleanupFunctions: Array<() => void> = [] const getRoot = () => fileDiffRef?.shadowRoot ?? undefined - const getVirtualizer = () => { - if (sharedVirtualizer) return sharedVirtualizer.virtualizer - - const result = acquireVirtualizer(container) - if (!result) return - - sharedVirtualizer = result - return result.virtualizer - } - const applyScheme = () => { const scheme = document.documentElement.dataset.colorScheme if (scheme === "dark" || scheme === "light") { @@ -227,27 +215,14 @@ export function Diff(props: SSRDiffProps) { onCleanup(() => monitor.disconnect()) } - const virtualizer = getVirtualizer() - - fileDiffInstance = virtualizer - ? new VirtualizedFileDiff( - { - ...createDefaultOptions(props.diffStyle), - ...others, - ...props.preloadedDiff, - }, - virtualizer, - virtualMetrics, - workerPool, - ) - : new FileDiff( - { - ...createDefaultOptions(props.diffStyle), - ...others, - ...props.preloadedDiff, - }, - workerPool, - ) + fileDiffInstance = new FileDiff( + { + ...createDefaultOptions(props.diffStyle), + ...others, + ...props.preloadedDiff, + }, + workerPool, + ) // @ts-expect-error - fileContainer is private but needed for SSR hydration fileDiffInstance.fileContainer = fileDiffRef fileDiffInstance.hydrate({ @@ -301,8 +276,6 @@ export function Diff(props: SSRDiffProps) { // Clean up FileDiff event handlers and dispose SolidJS components fileDiffInstance?.cleanUp() cleanupFunctions.forEach((dispose) => dispose()) - sharedVirtualizer?.release() - sharedVirtualizer = undefined }) return ( diff --git a/packages/ui/src/components/diff.tsx b/packages/ui/src/components/diff.tsx index 0966db75e03..8e4fd64fbe4 100644 --- a/packages/ui/src/components/diff.tsx +++ b/packages/ui/src/components/diff.tsx @@ -1,9 +1,8 @@ import { checksum } from "@opencode-ai/util/encode" -import { FileDiff, type SelectedLineRange, VirtualizedFileDiff } from "@pierre/diffs" +import { FileDiff, type SelectedLineRange } from "@pierre/diffs" import { createMediaQuery } from "@solid-primitives/media" import { createEffect, createMemo, createSignal, onCleanup, splitProps } from "solid-js" import { createDefaultOptions, type DiffProps, styleVariables } from "../pierre" -import { acquireVirtualizer, virtualMetrics } from "../pierre/virtualizer" import { getWorkerPool } from "../pierre/worker" type SelectionSide = "additions" | "deletions" @@ -53,7 +52,6 @@ function findSide(node: Node | null): SelectionSide | undefined { export function Diff(props: DiffProps) { let container!: HTMLDivElement let observer: MutationObserver | undefined - let sharedVirtualizer: NonNullable> | undefined let renderToken = 0 let selectionFrame: number | undefined let dragFrame: number | undefined @@ -94,16 +92,6 @@ export function Diff(props: DiffProps) { const [current, setCurrent] = createSignal | undefined>(undefined) const [rendered, setRendered] = createSignal(0) - const getVirtualizer = () => { - if (sharedVirtualizer) return sharedVirtualizer.virtualizer - - const result = acquireVirtualizer(container) - if (!result) return - - sharedVirtualizer = result - return result.virtualizer - } - const getRoot = () => { const host = container.querySelector("diffs-container") if (!(host instanceof HTMLElement)) return @@ -529,15 +517,12 @@ export function Diff(props: DiffProps) { createEffect(() => { const opts = options() const workerPool = getWorkerPool(props.diffStyle) - const virtualizer = getVirtualizer() const annotations = local.annotations const beforeContents = typeof local.before?.contents === "string" ? local.before.contents : "" const afterContents = typeof local.after?.contents === "string" ? local.after.contents : "" instance?.cleanUp() - instance = virtualizer - ? new VirtualizedFileDiff(opts, virtualizer, virtualMetrics, workerPool) - : new FileDiff(opts, workerPool) + instance = new FileDiff(opts, workerPool) setCurrent(instance) container.innerHTML = "" @@ -624,8 +609,6 @@ export function Diff(props: DiffProps) { instance?.cleanUp() setCurrent(undefined) - sharedVirtualizer?.release() - sharedVirtualizer = undefined }) return
diff --git a/packages/ui/src/pierre/index.ts b/packages/ui/src/pierre/index.ts index dc9d857bf87..e47433ce301 100644 --- a/packages/ui/src/pierre/index.ts +++ b/packages/ui/src/pierre/index.ts @@ -136,7 +136,7 @@ export function createDefaultOptions(style: FileDiffOptions["diffStyle"]) lineHoverHighlight: "both", disableBackground: false, expansionLineCount: 20, - hunkSeparators: "line-info-basic", + hunkSeparators: "line-info", lineDiffType: style === "split" ? "word-alt" : "none", maxLineDiffLength: 1000, maxLineLengthForHighlighting: 1000, diff --git a/packages/ui/src/pierre/virtualizer.ts b/packages/ui/src/pierre/virtualizer.ts index 4957afc1255..8b3d4ea077a 100644 --- a/packages/ui/src/pierre/virtualizer.ts +++ b/packages/ui/src/pierre/virtualizer.ts @@ -1,76 +1,15 @@ -import { type VirtualFileMetrics, Virtualizer } from "@pierre/diffs" - -type Target = { - key: Document | HTMLElement - root: Document | HTMLElement - content: HTMLElement | undefined -} - -type Entry = { - virtualizer: Virtualizer - refs: number +type VirtualMetrics = { + lineHeight: number + hunkSeparatorHeight: number + fileGap: number } -const cache = new WeakMap() - -export const virtualMetrics: Partial = { +export const virtualMetrics: Partial = { lineHeight: 24, hunkSeparatorHeight: 24, fileGap: 0, } -function target(container: HTMLElement): Target | undefined { - if (typeof document === "undefined") return - - const root = container.closest("[data-component='session-review']") - if (root instanceof HTMLElement) { - const content = root.querySelector("[data-slot='session-review-container']") - return { - key: root, - root, - content: content instanceof HTMLElement ? content : undefined, - } - } - - return { - key: document, - root: document, - content: undefined, - } -} - -export function acquireVirtualizer(container: HTMLElement) { - const resolved = target(container) - if (!resolved) return - - let entry = cache.get(resolved.key) - if (!entry) { - const virtualizer = new Virtualizer() - virtualizer.setup(resolved.root, resolved.content) - entry = { - virtualizer, - refs: 0, - } - cache.set(resolved.key, entry) - } - - entry.refs += 1 - let done = false - - return { - virtualizer: entry.virtualizer, - release() { - if (done) return - done = true - - const current = cache.get(resolved.key) - if (!current) return - - current.refs -= 1 - if (current.refs > 0) return - - current.virtualizer.cleanUp() - cache.delete(resolved.key) - }, - } +export function acquireVirtualizer(_container: HTMLElement) { + return } diff --git a/packages/ui/src/pierre/worker.ts b/packages/ui/src/pierre/worker.ts index 1993ad7aa6f..0d117c3683f 100644 --- a/packages/ui/src/pierre/worker.ts +++ b/packages/ui/src/pierre/worker.ts @@ -21,7 +21,6 @@ function createPool(lineDiffType: "none" | "word-alt") { { theme: "OpenCode", lineDiffType, - preferredHighlighter: "shiki-wasm", }, )