diff --git a/packages/db/src/schema-types.ts b/packages/db/src/schema-types.ts index 6a63423a44..4bacb6abd9 100644 --- a/packages/db/src/schema-types.ts +++ b/packages/db/src/schema-types.ts @@ -127,6 +127,11 @@ export type OrganizationPlan = z.infer; const OrganizationSettingsSchema = z.object({ model_allow_list: z.array(z.string()).optional(), provider_allow_list: z.array(z.string()).optional(), + + // under development, not yet enforced, will replace model_allow_list and provider_allow_list: + model_deny_list: z.array(z.string()).optional(), + provider_deny_list: z.array(z.string()).optional(), + default_model: z.string().optional(), data_collection: z.enum(['allow', 'deny']).nullable().optional(), // null means they were grandfathered in and so they have usage limits enabled diff --git a/src/lib/model-allow.server.ts b/src/lib/model-allow.server.ts index 82b9985483..1f7ede9b5c 100644 --- a/src/lib/model-allow.server.ts +++ b/src/lib/model-allow.server.ts @@ -1,7 +1,10 @@ import 'server-only'; import { normalizeModelId } from '@/lib/model-utils'; -import { getProviderSlugsForModel } from '@/lib/providers/openrouter/models-by-provider-index.server'; +import { + fetchLatestModelsByProviderSnapshotFromDb, + getProviderSlugsForModel, +} from '@/lib/providers/openrouter/models-by-provider-index.server'; import { isAllowedByExactOrNamespaceWildcard, isAllowedByProviderMembershipWildcard, @@ -44,3 +47,32 @@ export function createProviderAwareModelAllowPredicate( return isAllowedByProviderMembershipWildcard(providersForModel, wildcardProviderSlugs); }; } + +export async function createDenyLists( + model_allow_list: string[] | undefined, + provider_allow_list: string[] | undefined +) { + if (!model_allow_list && !provider_allow_list) { + return undefined; + } + const data = await fetchLatestModelsByProviderSnapshotFromDb(); + if (!data) { + return undefined; + } + const isAllowed = model_allow_list + ? createProviderAwareModelAllowPredicate(model_allow_list) + : undefined; + const model_deny_list = new Set(); + const provider_deny_list = new Set(); + for (const provider of data.providers) { + if (provider_allow_list && !provider_allow_list.includes(provider.slug)) { + provider_deny_list.add(provider.slug); + } + for (const model of provider.models) { + if (isAllowed && !(await isAllowed(model.slug))) { + model_deny_list.add(model.slug); + } + } + } + return { model_deny_list: [...model_deny_list], provider_deny_list: [...provider_deny_list] }; +} diff --git a/src/lib/providers/openrouter/models-by-provider-index.server.ts b/src/lib/providers/openrouter/models-by-provider-index.server.ts index 3fb523fb4e..153628461c 100644 --- a/src/lib/providers/openrouter/models-by-provider-index.server.ts +++ b/src/lib/providers/openrouter/models-by-provider-index.server.ts @@ -85,7 +85,7 @@ export function createModelsByProviderIndexLoader(options: ProviderIndexLoaderOp }; } -async function fetchLatestModelsByProviderSnapshotFromDb(): Promise< +export async function fetchLatestModelsByProviderSnapshotFromDb(): Promise< NormalizedOpenRouterResponse | undefined > { const result = await db diff --git a/src/routers/organizations/organization-settings-router.ts b/src/routers/organizations/organization-settings-router.ts index fc6590bd47..dc8f09e81f 100644 --- a/src/routers/organizations/organization-settings-router.ts +++ b/src/routers/organizations/organization-settings-router.ts @@ -15,7 +15,7 @@ import * as z from 'zod'; import { createAuditLog } from '@/lib/organizations/organization-audit-logs'; import { getEnhancedOpenRouterModels } from '@/lib/providers/openrouter'; import { requireActiveSubscriptionOrTrial } from '@/lib/organizations/trial-middleware'; -import { createProviderAwareModelAllowPredicate } from '@/lib/model-allow.server'; +import { createDenyLists, createProviderAwareModelAllowPredicate } from '@/lib/model-allow.server'; import { KILO_ORGANIZATION_ID } from '@/lib/organizations/constants'; import { listAvailableCustomLlms } from '@/lib/custom-llm/listAvailableCustomLlms'; @@ -223,6 +223,15 @@ export const organizationsSettingsRouter = createTRPCRouter({ settingsUpdate.provider_allow_list = [...new Set(provider_allow_list)]; // Deduplicate slugs } + const denyLists = await createDenyLists( + settingsUpdate.model_allow_list, + settingsUpdate.provider_allow_list + ); + if (denyLists) { + settingsUpdate.model_deny_list = denyLists.model_deny_list; + settingsUpdate.provider_deny_list = denyLists.provider_deny_list; + } + // Check if default_model needs to be cleared if ( model_allow_list !== undefined &&